Ursprünglich wurden Diffusionsmodelle in der Informatik hauptsächlich verwendet um das Rauschen aus Bildern zu entfernen. Sie können aber auch als generative KI genutzt werden um Bilder zu erzeugen. Dafür wird ein neuronales Netzwerk darauf trainiert, den Prozess des Hinzufügens von Rauschen zu einem Bild umzukehren. Nach dem Training kann das Model dann für die Bildgenerierung verwendet werden, indem mit einem Bild begonnen wird, das aus zufälligem Rauschen besteht.
Diffusionsbasierte Bildgeneratoren haben ein breites kommerzielles Interesse gefunden. Beispiele hierfür sind Stable Diffusion, DALL-E oder auch WAN AI das auf open source code basiert. Diese Modelle kombinieren in der Regel Diffusionsmodelle mit anderen Modellen, z. B. Textencodern um eine textbedingte Generierung zu ermöglichen. Das untenstehende Bild wurde z.B. mit WAN2.1 aus dem Prompt: 'Ein Astronaut der ein Pferd auf dem Mond reitet' erzeugt.
Im Wesentlichen funktioniert der Trainingsprozess eines solchen Denoising Diffusion Probabalistic Models (DDPM) wie folgt. Einem Bild bzw. einem Satz von Bildern wird normal verteiltes Rauschen zugefügt. Anschließend wird ein Transformer Netzwerk oder ein faltendes Netzwerk (CNN) darauf trainiert, das Rauschen zu entfernen. Schritt für Schritt wird dem Modell weiter Rauschen zugefügt, und das Modell dadurch verfeinert. Der Prozess setzt sich fort, bis das Bild vollständig verrauscht ist.
Was einfach klingt, erweist sich im Detail doch schwierig. Würden wir einem Bildvektor \(x_0 \) lediglich normal verteiltes Rauschen \(\epsilon = \mathcal{N}(0, 1)\) mit Erwartungswert \(\mu=0\) und Standardabweichung \(\sigma^2=1\) der Stärke \(\beta\) gemäßimport numpy as np
# Zufallswerte generieren x_0=norm.rvs(loc=0, scale=1, size=1000) epsilon = norm.rvs(loc=0, scale=1, size=1000)
beta =10 x_1=x_0+beta*epsilon
# Standardabweichung und Varianz berechnen sigma = np.std(x_1, ddof=0) # Standardabweichung der Stichprobe, „degrees of freedom“ = 1, sonst ddof=0 print(f"Standardabweichung (σ): {sigma:.4f}") varianz = np.var(x_1) print(f"Varianz: {varianz:.2f}")
\[q(x_n \mid x_0) \xrightarrow{n \to \infty} \mathcal{N}(0, 1)\]
$$x_n = (1 - \beta) \cdot x_{n-1} + \beta \cdot \epsilon$$
Wie eine einfache Rechnung zeigt ergibt sich \(x_n\) in Abhängigkeit von \(x_0\) dann zu
$$x_n = \alpha^n \cdot x_0 + (1 - \alpha^n) \cdot \epsilon$$ mit \(\alpha_n=1-\beta\). Wird \(\beta\) zwischen 0 und 1 gewählt, geht \(\alpha\) und damit der Erwartungswert von \(x_n\) gegen 0 und die Standardabweichung gegen 1.
Durch Verwendung eines 'noise schedules' aus N Schritten können auch für jeden Iterationsschritt unterschiedliche \(\alpha\) -Werte gewählt werden. Für \(x_N\) gilt dann ganz allgemein:
\[x_N = \prod_{i=1}^{N} \alpha_i \cdot x_0 + (1-\prod_{i=1}^{N}\alpha_i )\epsilon\]
Da die Standardabweichung die Wurzel der Varianz darstellt, wird in vielen Anwendungsbeispielen ein wurzelförmiger 'schedule' gewählt. Der Ausdruck \((1-\prod_{i=1}^{N}\alpha_i )\) steht dann für die Entwicklung der Varianz und nicht der Standardabweichung.
In meinem DDPM Bespiel habe ich das 'nn' package von PyTorch verwendet. Die 'add noise' Funktion ist wie folgt definiert.
def add_noise(x, t, noise_schedule): noise = torch.randn_like(x) # creates a new tensor with the same shape, data type, layout, and device as the input tensor x, but fills it with random values sampled from a standard normal distribution (mean = 0, standard deviation = 1) alpha = noise_schedule[t].sqrt().reshape(-1, 1, 1, 1) #reshape into 1 'column' of 3 dimensional arrays return alpha * x + (1 - alpha) * noise
Um das Rauschen nach jedem Iterationsschritt zu entfernen, habe ich ein einfaches Convolutional Neural Netweork (CNN) programmiert.
class SimpleDDPM(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Conv2d(1, 64, 3, padding=1), # applies 64 filters (each of size 3×3), producing 64 feature maps. nn.Conv2d(64, 1, 3, padding=1) ) def forward(self, x, t): return self.model(x)
Die Conv2d-Schicht in torch legt keine konkreten Filter wie „Kantenerkennung“ oder „Mustererkennung“ fest – sie lernt die Filterparameter von hier 64 Filtern während des Trainings. Das neuronale Netzwerk versucht also das rauschfreie Bild zu erzeugen. In meinem Anwendungsbeispiel soll das neuronale Netzwerk mit dem MNIST Datensatz trainiert werden, um später handschriftliche Ziffern erzeugen zu können.
Ich habe das gesamte Programm als Kaggle notebook angelegt. Nach dem Einlesen der notwendigen Bibliotheken, werden einige globale Variablen festgelegt, und der MNIST Datensatz geladen. Um Rechenzeit zu sparen skaliere ich die Bilder auf eine Größe von 14x14 Pixel und lade nur 3000 Bilder.
import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader, Subset import matplotlib.pyplot as plt from tqdm import tqdm # define hyperparameter pic_size = 14 # size of pictures train_size = 3000 # size of dataset # noiseschedule timesteps = 50 # number of timesteps noise_schedule = torch.linspace(1e-4, 1, timesteps) #e.g. 100 steps in between 1e-4 and 1 # training epochNr=100 # numer of iterations lrate=1e-5 # learning rate # load MNIST dataset transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((pic_size, pic_size))]) #transform to tensor and downscale to e.g. 14x14 mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transform) subset_indices = list(range(train_size)) #load e.g. 3000 pictures only subset = Subset(mnist, subset_indices) data_loader = DataLoader(subset, batch_size=64, shuffle=True) #data will be loaded in batches of 64 samples # get some samples of the dataset images, labels = next(iter(data_loader)) # and plot the samples fig, axes = plt.subplots(1, 6, figsize=(10, 2)) for i in range(6): axes[i].imshow(images[i][0], cmap="gray") axes[i].set_title(f"Label: {labels[i].item()}") axes[i].axis("off") plt.tight_layout() plt.show()
Die reduzierten Bilder sehen dann exemplarisch wie folgt aus.
Nach der bereits bekannten Definition der 'add_noise' Funktion und des neuronalen Netzwerks selbst, habe ich eine Funktion eingefügt, die es erlaubt den Vorgang des Entfernens von Rauschen bildlich zu verfolgen.
# define forward diffusion (adding noise) def add_noise(x, t, noise_schedule): noise = torch.randn_like(x) # create a new tensor with the same shape, data type, layout, and device as the input tensor x, and fill it with random values sampled from a standard normal distribution (mean = 0, standard deviation = 1) alpha = noise_schedule[t].sqrt().reshape(-1, 1, 1, 1) # reshape into 1 'column' of 3 dimensional arrays return alpha * x + (1 - alpha) * noise # define the diffusion model (CNN) class SimpleDDPM(nn.Module): def __init__(self): super().__init__() self.model = nn.Sequential( nn.Conv2d(1, 64, 3, padding=1), #appla 64 filters (each of size 3×3), producing 64 feature maps nn.Conv2d(64, 1, 3, padding=1) ) def forward(self, x, t): return self.model(x) # visualize denoising over time def visualize_denoising(model, x, timesteps_to_show, noise_schedule): model.eval() # switch to eval mode with torch.no_grad(): fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(15, 3)) for i, t_val in enumerate(timesteps_to_show): t = torch.full((x.size(0),), t_val, dtype=torch.long) noisy_x = add_noise(x, t, noise_schedule) # add noise pred_x = model(noisy_x, t) # model predicts denoised image axes[i].imshow(pred_x[0][0].cpu(), cmap="gray") axes[i].set_title(f"Timestep: {t_val}") axes[i].axis("off") plt.tight_layout() plt.show() model.train() # switch back to training mode
Die Sequenz unten zeigt in der ersten Zeile wie das Rauschen nach 20 Iterationen in Zeitschritten von 0 bis 50 aus einem ausgewählten Bild entfernt wird. Die zweite Zeile zeigt die gleichen Zeitschritte nach 40 Iterationen, die dritte nach 60 Iterationen, und die letzte Zeile nach Beendigung des Trainings von 100 Iterationen
Im Hauptteil des Programms wird das Training ausgeführt, und die 'loss' Funktion geplottet.
# train the model model = SimpleDDPM() optimizer = torch.optim.Adam(model.parameters(), lr=lrate) #adam optimization algorithm at defined learning rate epochs = epochNr losses = [] # array for loss values for epoch in range(epochs): for x, _ in tqdm(data_loader): x = x t = torch.randint(0, timesteps, (x.size(0),)) noisy_x = add_noise(x, t, noise_schedule) noise_pred = model(noisy_x, t) loss = F.mse_loss(noise_pred, x) optimizer.zero_grad() loss.backward() optimizer.step() losses.append(loss.item()) # store loss from this epoch if (epoch + 1) % 20 == 0: # show denoising for every 20th epoch print(f"Epoch {epoch+1} Loss: {loss.item():.4f}") # pick one image from your dataset sample_x, _ = next(iter(data_loader)) sample_x = sample_x[:1] # use a single image timesteps_to_show = [0, timesteps//4, timesteps//2, 3*timesteps//4, timesteps-1] # show 5 timesteps visualize_denoising(model, sample_x, timesteps_to_show, noise_schedule) plt.plot(range(1, epochs+1), losses, marker='o') plt.xlabel("Epoch") plt.ylabel("Loss") plt.title("Training Loss over Epochs") plt.grid(True) plt.show()
Das Bild unten zeigt die loss Kurve exemplarisch nach 200 Iterationen bei einer Lernrate von 1e-5.
Um nun ein Bild einer Ziffer generieren zu können muss das Netzerk erneut durchlaufen werden.
# define how to generate images from pure noise using a trained DDPM model. def sample_images_from_ddpm(model, noise_schedule, pic_size=14, timesteps=50, num_samples=6): """ Args: model: Trained DDPM model. noise_schedule: Schedule of noise levels (torch tensor). pic_size: Size of generated images (e.g. 14). timesteps: Number of timesteps used in training/sampling. num_samples: Number of images to generate. Returns: A tensor of shape (num_samples, 1, pic_size, pic_size) with generated images. """ model.eval() sample = torch.randn(num_samples, 1, pic_size, pic_size) with torch.no_grad(): for t in range(timesteps): t_tensor = torch.full((num_samples,), t, dtype=torch.long) pred = model(sample, t_tensor) alpha = noise_schedule[t] # alpha = noise_schedule[t].sqrt() alpha = alpha.view(1, 1, 1, 1) # make dimensions broadcastable noise = 7*torch.randn_like(pred) #sigma = 2, mu = 0 sample = alpha * pred + (1 - alpha) * noise # reverse diffusion step return sample # Generate 6 samples samples = sample_images_from_ddpm(model, noise_schedule, pic_size, timesteps, num_samples=6) # Plot the results fig, axes = plt.subplots(1, 6, figsize=(12, 2)) for i in range(6): axes[i].imshow(samples[i][0].cpu(),cmap='gray') axes[i].axis("off") plt.tight_layout() plt.show()
Bei einem sigma Wert von 2 werden aus einem rein zufällig verteilten Rauschen Bilder erzeugt, die zumindest einer Überlagerung von einzelnen trainierten Ziffern ähneln.
Keine Kommentare:
Kommentar veröffentlichen