Freitag, 3. Oktober 2025

Erster Versuch ein Denoising Diffusion Probabalistic Model zu programmieren

Ursprünglich wurden Diffusionsmodelle in der Informatik hauptsächlich verwendet um das Rauschen aus Bildern zu entfernen. Sie können aber auch als generative KI genutzt werden um Bilder zu erzeugen. Dafür wird ein neuronales Netzwerk darauf trainiert, den Prozess des Hinzufügens von Rauschen zu einem Bild umzukehren. Nach dem Training  kann das Model dann für die Bildgenerierung verwendet werden, indem mit einem Bild begonnen wird, das aus zufälligem Rauschen besteht.

Diffusionsbasierte Bildgeneratoren haben ein breites kommerzielles Interesse gefunden. Beispiele hierfür sind Stable DiffusionDALL-E oder auch WAN AI das auf open source code basiert. Diese Modelle kombinieren in der Regel Diffusionsmodelle mit anderen Modellen, z. B. Textencodern um eine textbedingte Generierung zu ermöglichen. Das untenstehende Bild wurde z.B. mit WAN2.1 aus dem Prompt: 'Ein Astronaut der ein Pferd auf dem Mond reitet' erzeugt.

Im Wesentlichen funktioniert der Trainingsprozess eines solchen Denoising Diffusion Probabalistic Models (DDPM) wie folgt. Einem Bild bzw. einem Satz von Bildern wird normal verteiltes Rauschen zugefügt. Anschließend wird ein Transformer Netzwerk oder ein faltendes Netzwerk (CNN) darauf trainiert, das Rauschen zu entfernen. Schritt für Schritt wird dem Modell weiter Rauschen zugefügt, und das Modell dadurch verfeinert. Der Prozess setzt sich fort, bis das Bild vollständig verrauscht ist.


Was einfach klingt, erweist sich im Detail doch schwierig. Würden wir einem Bildvektor \(x_0 \) lediglich normal verteiltes Rauschen \(\epsilon = \mathcal{N}(0, 1)\) mit Erwartungswert \(\mu=0\) und Standardabweichung \(\sigma^2=1\) der Stärke \(\beta\) gemäß
\[x_1 = x_0 + \beta \cdot \epsilon\]
zufügen, dann würde sich nach \(n\) Schritten 

\[x_n = x_0 + n \cdot \beta \cdot \epsilon\]

ergeben. Nach sehr vielen Schritten entspricht die Werteverteilung der Pixel dann einer Normalverteilung  \(\mathcal{N}(\bar{x_0}, n\cdot\beta)\)  um \(\bar{x_0}\) und unendlich großer Standardabweichung \(n\cdot\beta\). 
Daß die Standardabweichung mit n skaliert kann man exemplarisch mit einem einfachen Beispiel plausibel machen.

import numpy as np
# Zufallswerte generieren
x_0=norm.rvs(loc=0, scale=1, size=1000)
epsilon = norm.rvs(loc=0, scale=1, size=1000)
beta =10
x_1=x_0+beta*epsilon
# Standardabweichung und Varianz berechnen
sigma = np.std(x_1, ddof=0) #	Standardabweichung der Stichprobe, „degrees of freedom“ = 1, sonst ddof=0
print(f"Standardabweichung (σ): {sigma:.4f}")
varianz = np.var(x_1)
print(f"Varianz: {varianz:.2f}")

Für zwei Normalverteilungen um 0 mit Standardabweichung 1 und \(\beta = 10\) liefert das script eine Standardabweichung von  9.8116 und eine Varianz von 96.27. Das ist offensichtlich nicht was wir wollen. Ziel ist es das der Grenzwert der bedingten Verteilung \(q(x_n \mid x_0)\) also der Verteilung von \(x_n\) bei gegebenem \(x_0\) zu einer Normalverteilung  \(\mathcal{N}(0, 1)\) um 0 mit Standardabweichung 1 wird. Es soll also

\[q(x_n \mid x_0) \xrightarrow{n \to \infty} \mathcal{N}(0, 1)\]

gelten. Um dies zu erreichen definieren wir

$$x_n  = (1 - \beta) \cdot x_{n-1} + \beta \cdot \epsilon$$

Wie eine einfache Rechnung zeigt ergibt sich \(x_n\) in Abhängigkeit von \(x_0\) dann zu

$$x_n = \alpha^n \cdot x_0 + (1 - \alpha^n) \cdot \epsilon$$ mit \(\alpha_n=1-\beta\). Wird \(\beta\) zwischen 0 und 1 gewählt, geht \(\alpha\) und damit der Erwartungswert von \(x_n\) gegen 0 und die Standardabweichung gegen 1. 

Durch Verwendung eines 'noise schedules' aus N Schritten können auch für jeden Iterationsschritt unterschiedliche \(\alpha\) -Werte gewählt werden.  Für \(x_N\) gilt dann ganz allgemein:

\[x_N = \prod_{i=1}^{N} \alpha_i \cdot x_0 + (1-\prod_{i=1}^{N}\alpha_i )\epsilon\]

Da die Standardabweichung die Wurzel der Varianz darstellt, wird in vielen Anwendungsbeispielen ein wurzelförmiger 'schedule' gewählt. Der Ausdruck \((1-\prod_{i=1}^{N}\alpha_i )\) steht dann für die Entwicklung der Varianz und nicht der Standardabweichung.

In meinem DDPM Bespiel habe ich das 'nn' package von PyTorch verwendet. Die 'add noise' Funktion ist wie folgt definiert.

def add_noise(x, t, noise_schedule):
    noise = torch.randn_like(x) # creates a new tensor with the same shape, data type, layout, and device as the input tensor x, but fills it with random values sampled from a standard normal distribution (mean = 0, standard deviation = 1)
    alpha = noise_schedule[t].sqrt().reshape(-1, 1, 1, 1) #reshape into 1 'column' of 3 dimensional arrays 
    return alpha * x + (1 - alpha) * noise

Um das Rauschen nach jedem Iterationsschritt zu entfernen, habe ich ein einfaches Convolutional Neural Netweork (CNN) programmiert.

class SimpleDDPM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), # applies 64 filters (each of size 3×3), producing 64 feature maps.
            nn.Conv2d(64, 1, 3, padding=1)
        )
    def forward(self, x, t):
        return self.model(x)

Die Conv2d-Schicht in torch legt keine konkreten Filter wie „Kantenerkennung“ oder „Mustererkennung“ fest – sie lernt die Filterparameter von hier 64 Filtern während des Trainings. Das neuronale Netzwerk versucht also das rauschfreie Bild zu erzeugen. In meinem Anwendungsbeispiel soll das neuronale Netzwerk mit dem MNIST Datensatz trainiert werden, um später handschriftliche Ziffern erzeugen zu können. 

Ich habe das gesamte Programm als Kaggle notebook angelegt. Nach dem Einlesen der notwendigen Bibliotheken, werden einige globale Variablen festgelegt, und der MNIST Datensatz geladen. Um Rechenzeit zu sparen skaliere ich die Bilder auf eine Größe von 14x14 Pixel und lade nur 3000 Bilder. 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from tqdm import tqdm

# define hyperparameter
pic_size = 14 # size of pictures
train_size = 3000 # size of dataset
# noiseschedule
timesteps = 50 # number of timesteps
noise_schedule = torch.linspace(1e-4, 1, timesteps) #e.g. 100 steps in between 1e-4 and 1
# training
epochNr=100 # numer of iterations
lrate=1e-5 # learning rate

# load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Resize((pic_size, pic_size))]) #transform to tensor and downscale to e.g. 14x14
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
subset_indices = list(range(train_size)) #load e.g. 3000 pictures only
subset = Subset(mnist, subset_indices)
data_loader = DataLoader(subset, batch_size=64, shuffle=True) #data will be loaded in batches of 64 samples

# get some samples of the dataset
images, labels = next(iter(data_loader))
# and plot the samples
fig, axes = plt.subplots(1, 6, figsize=(10, 2))
for i in range(6):
    axes[i].imshow(images[i][0], cmap="gray")
    axes[i].set_title(f"Label: {labels[i].item()}")
    axes[i].axis("off")
plt.tight_layout()
plt.show()

Die reduzierten Bilder sehen dann exemplarisch wie folgt aus.


Nach der bereits bekannten Definition der 'add_noise' Funktion und des neuronalen Netzwerks selbst, habe ich eine Funktion eingefügt, die es erlaubt den Vorgang des Entfernens von Rauschen bildlich zu verfolgen. 

# define forward diffusion (adding noise)
def add_noise(x, t, noise_schedule):
    noise = torch.randn_like(x) # create a new tensor with the same shape, data type, layout, and device as the input tensor x, and fill it with random values sampled from a standard normal distribution (mean = 0, standard deviation = 1)
    alpha = noise_schedule[t].sqrt().reshape(-1, 1, 1, 1) # reshape into 1 'column' of 3 dimensional arrays
    return alpha * x + (1 - alpha) * noise
 
# define the diffusion model (CNN)
class SimpleDDPM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), #appla 64 filters (each of size 3×3), producing 64 feature maps
            nn.Conv2d(64, 1, 3, padding=1)
        )
    def forward(self, x, t):
        return self.model(x)
        
# visualize denoising over time
def visualize_denoising(model, x, timesteps_to_show, noise_schedule):
    model.eval()  # switch to eval mode
    with torch.no_grad():
        fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(15, 3))
        for i, t_val in enumerate(timesteps_to_show):
            t = torch.full((x.size(0),), t_val, dtype=torch.long)
            noisy_x = add_noise(x, t, noise_schedule)  # add noise
            pred_x = model(noisy_x, t)  # model predicts denoised image
            
            axes[i].imshow(pred_x[0][0].cpu(), cmap="gray")
            axes[i].set_title(f"Timestep: {t_val}")
            axes[i].axis("off")
        
        plt.tight_layout()
        plt.show()
    model.train()  # switch back to training mode

Die Sequenz unten zeigt in der ersten Zeile wie das Rauschen nach 20 Iterationen in Zeitschritten von 0 bis 50 aus einem ausgewählten Bild entfernt wird. Die zweite Zeile zeigt die gleichen Zeitschritte nach 40 Iterationen, die dritte nach 60 Iterationen, und die letzte Zeile nach Beendigung des Trainings von 100 Iterationen


Im Hauptteil des Programms wird das Training ausgeführt, und die 'loss' Funktion geplottet.
# train the model
model = SimpleDDPM()
optimizer = torch.optim.Adam(model.parameters(), lr=lrate) #adam optimization algorithm at defined learning rate
epochs = epochNr
losses = [] # array for loss values

for epoch in range(epochs):
    for x, _ in tqdm(data_loader):
        x = x 
        t = torch.randint(0, timesteps, (x.size(0),))

        noisy_x = add_noise(x, t, noise_schedule)
        noise_pred = model(noisy_x, t)

        loss = F.mse_loss(noise_pred, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(loss.item())  # store loss from this epoch
    
    if (epoch + 1) % 20 == 0: # show denoising for every 20th epoch
        print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")
        # pick one image from your dataset
        sample_x, _ = next(iter(data_loader))
        sample_x = sample_x[:1]  # use a single image
        
        timesteps_to_show = [0, timesteps//4, timesteps//2, 3*timesteps//4, timesteps-1]  # show 5 timesteps
        visualize_denoising(model, sample_x, timesteps_to_show, noise_schedule)

plt.plot(range(1, epochs+1), losses, marker='o')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.grid(True)
plt.show()

Das Bild unten zeigt die loss Kurve exemplarisch nach 200 Iterationen bei einer Lernrate von 1e-5.

Um nun ein Bild einer Ziffer generieren zu können muss das Netzerk erneut durchlaufen werden.

# define how to generate images from pure noise using a trained DDPM model.
def sample_images_from_ddpm(model, noise_schedule, pic_size=14, timesteps=50, num_samples=6):
    """
    Args:
        model: Trained DDPM model.
        noise_schedule: Schedule of noise levels (torch tensor).
        pic_size: Size of generated images (e.g. 14).
        timesteps: Number of timesteps used in training/sampling.
        num_samples: Number of images to generate.
    
    Returns:
        A tensor of shape (num_samples, 1, pic_size, pic_size) with generated images.
    """
    model.eval()
    sample = torch.randn(num_samples, 1, pic_size, pic_size)

    with torch.no_grad():
        for t in range(timesteps):
            t_tensor = torch.full((num_samples,), t, dtype=torch.long)
            pred = model(sample, t_tensor)

            alpha = noise_schedule[t] # alpha = noise_schedule[t].sqrt()
            alpha = alpha.view(1, 1, 1, 1)  # make dimensions broadcastable
            noise = 7*torch.randn_like(pred) #sigma = 2, mu = 0
            sample = alpha * pred + (1 - alpha) * noise  # reverse diffusion step

    return sample
    
# Generate 6 samples
samples = sample_images_from_ddpm(model, noise_schedule, pic_size, timesteps, num_samples=6)

# Plot the results
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(samples[i][0].cpu(),cmap='gray')
    axes[i].axis("off")
plt.tight_layout()
plt.show()

Bei einem sigma Wert von 2 werden aus einem rein zufällig verteilten Rauschen Bilder erzeugt, die zumindest einer Überlagerung von einzelnen trainierten Ziffern ähneln. 


Klar definierte Ziffern lassen sich jedoch nicht generieren. Vermutlich liegt das daran, dass ich versucht habe das rauschfreie Bild direkt durch das neuronale Netzwerk zu erzeugen. Im originalen Ansatz von Jonathan Ho, Ajay Jain und Pieter Abbeel wird jedoch ein U-net darauf trainiert, nicht das Bild sondern das Rauschen zu ermitteln das von einem Bild 'subtrahiert' werden muss. 

Bis zum nächsten Artikel also, und viel Spaß beim selber programmieren...

Keine Kommentare:

Kommentar veröffentlichen