Im zweiten Versuch ein Denoising Diffusion Probabilistic Model zu programmieren folge ich einem Artikel von Jonathan Ho, Ajay Jain und Pieter Abbeel zu diesem Thema. Wie bereits im ersten Versuch erläutert, wird im Forward-Prozess schrittweise Rauschen zu einem Eingabebild \( x_0 \) hinzugefügt. Das Rauschen hat eine zeitabhängige Normalverteilung, wobei ein Parameter \( \beta_t \) die Intensität des Rauschens steuert.
Die Verteilung der 'Farbwerte' von \(x_t\) hängt nur von \(x_{t-1}\) ab. Es gilt:
$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \cdot x_{t-1}, \beta_t \cdot I) $$
Sie besitzt einen Mittelwert von \(\sqrt{1 - \beta_t} \cdot x_{t-1}\) und eine Varianz von \(\beta_t \cdot I\) . Dabei ist \(I\) die Einheitsmatrix, weil die Störung gleichmäßig in jede Richtung erfolgen soll. Für alle Zeitschritte \( t \) gilt \( \beta_t < 1 \) . Für die Konfiguration des Forward-Prozesses habe ich eine Anzahl von \( T\) = 1000 Zeitschritten gewählt, wobei die Rauschparameter linear von \( \beta_1 = 10^{-4} \) bis \( \beta_T\) = 0,02 ansteigen. Die lineare Zeitplanung erfolgt nach der Formel
$$ \beta_t = \beta_1 + \frac{(t - 1)}{T} \cdot \beta_T $$
Für ein gegebenes Bild \( x_0 \) und die festgelegten \( \beta_t \) lassen sich alle Zwischenzustände \( x_t \) direkt berechnen, ohne dass man rekursiv über alle vorherigen Schritte gehen muss. Die entsprechende Verteilung lautet hierfür
$$ q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} \cdot x_0, (1 - \bar{\alpha}_t) \cdot I) $$
wobei
$$ \alpha_t = 1 - \beta_t $$
und
$$ \bar{\alpha}_t =\prod_{s=1}^{t} \alpha_s $$
gilt. Für \( x_t \) selbst gilt in Abhängigkeit von \( x_0 \):
$$ x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon $$
\( \epsilon \sim \mathcal{N}(0, I) \) ist ein Zufallsrauschen. Der gesamte Vorwärts-Prozess enthält also keine trainierbaren Parameter. Er dient lediglich dazu, das ursprüngliche Bild systematisch in Rauschen zu überführen und bildet die Grundlage für das spätere Lernen im Reverse-Prozess. Im Programm wird der Vorwärtsprozess in der Klasse 'DiffusionForwardProcess' wie folgt definiert:
# Importe import torch import torch.nn as nn class DiffusionForwardProcess: """ Forward process as described in the paper “Denoising Diffusion Probabilistic Models". """ def __init__(self, num_time_steps = 1000, beta_start = 1e-4, beta_end = 0.02 ): # Vorberechnung von beta, alpha und alpha_bar für alle Zeitschritte t self.betas = torch.linspace(beta_start, beta_end, num_time_steps) # 1D-Tensor mit gleichmäßig verteilten Werten zwischen beta_start und beta_end self.alphas = 1 - self.betas # alpha = 1 - beta self.alpha_bars = torch.cumprod(self.alphas, dim=0) # Kumulatives Produkt: z. B. [α₁, α₁·α₂, α₁·α₂·α₃, ...] self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars) # Quadratwurzel von alpha_bar self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars) # Quadratwurzel von (1 - alpha_bar) def add_noise(self, original, noise, t): # Hinweis: alle Argumente müssen Tensoren sein """ Fügt einem Batch von Originalbildern zum Zeitpunkt t Rauschen hinzu. :parameter original: Eingabebild als Tensor :parameter noise: Zufallsrauschen aus der Normalverteilung N(0, 1) :parameter t: Zeitschritt des Forward-Prozesses, shape: (B,) Hinweis: Der Zeitschritt t kann für jedes Bild im Batch unterschiedlich sein. """ sqrt_alpha_bar_t = self.sqrt_alpha_bars.to(original.device)[t] # Auf GPU übertragen, falls das Originalbild auf der GPU liegt sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_bars.to(original.device)[t] # Broadcasting zur Multiplikation mit dem Originalbild sqrt_alpha_bar_t = sqrt_alpha_bar_t[:, None, None, None] # Zusätzliche Dimensionen für (B, C, H, W) sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t[:, None, None, None] # Rückgabe des verrauschten Bildes return (sqrt_alpha_bar_t * original) + (sqrt_one_minus_alpha_bar_t * noise) # Jeder Pixel des Originalbildes wird mit sqrt_alpha_bar_t multipliziert und das Rauschen addiert. # Falls t z. B. vier Zeitschritte enthält, werden vier Bilder mit entsprechendem Rauschen erzeugt.
Untenstehender Testaufruf gibt 4 unterschiedliche Bilder zu 4 unterschiedlichen Zeitpunkten zurück.
# Test original = torch.randn(4,1, 28, 28) # batch aus 4 Bildern mit Kanalzahl 1, Höhe 28 und Breite 28 noise = torch.randn(1, 28, 28) t_steps = torch.randint(0, 1000, (4,)) #1D Tensor aus 4 Zufallszahlen zwischen 0 und 1000 # Forward Process dfp = DiffusionForwardProcess() #Instanz der Klasse erzeugen out = dfp.add_noise(original, noise, t_steps) #gibt 4 unterschiedliche Bildern zu 4 unterschiedlichen Zeitpunkten zurück print(out.shape)
Die Dimensionen des verrauschten Bildbatches sind mit (4,1,28,28) identisch zum ursprünglichen batch.
Der Reverse-Prozess eines DDPM lässt sich als die Umkehrung des Forward-Prozesses verstehen. Ziel ist es, ein vollständig verrauschtes Bild \( x_T \sim \mathcal{N}(0, I) \) über \( T \) Zeitschritte hinweg schrittweise zu entrauschen, bis das ursprüngliche Bild rekonstruiert ist. Während der Forward-Prozess deterministisch ist und keine trainierbaren Parameter enthält, wird im Reverse-Prozess typischerweise ein U-Net eingesetzt, um das Rauschen \( \epsilon_\theta \) zu jedem Zeitpunkt zu schätzen. Die Rekonstruktion von \( x_{t-1} \) aus \( x_t \) erfolgt über die Gleichung
$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_{\theta} \right) + \sigma_t z $$
wobei \( z \sim \mathcal{N}(0, I) \) ein zusätzlicher Zufallsvektor ist, der die stochastische Natur der Diffusion bewahrt. Zusätzlich lässt sich zu jedem Zeitpunkt eine Schätzung des ursprünglichen Bildes \( x_0 \) berechnen, und zwar durch
$$ x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_{\theta} \right) $$
Diese Formeln bilden die Grundlage des Reverse-Prozesses, bei dem das Modell lernt, aus reinem Rauschen ein realistisches Bild zu generieren. Die Qualität der Rekonstruktion hängt dabei entscheidend von der Güte der Rauschvorhersage \( \epsilon_{\theta} \) ab, die wir im nächsten Schritt durch ein trainierbares U-Net-Modell implementieren werden.
Der Rückwärtsprozess wird in der unterstehenden Klasse 'DiffusionReverseProcess' definiert.
class DiffusionReverseProcess: """ Reverse process as described in the paper “Denoising Diffusion Probabilistic Models". """ def __init__(self, num_time_steps = 1000, beta_start = 1e-4, beta_end = 0.02 ): # Vorberechnung von beta, alpha und alpha_bar für alle Zeitschritte t self.b = torch.linspace(beta_start, beta_end, num_time_steps) # b → beta self.a = 1 - self.b # a → alpha self.a_bar = torch.cumprod(self.a, dim=0) # a_bar = alpha_bar (kumulatives Produkt) def sample_prev_timestep(self, xt, noise_pred, t): """ Berechnet x_{t-1} aus x_t und dem vom Modell vorhergesagten Rauschen. :param xt: Bildtensor zum Zeitpunkt t, Form → B x C x H x W :param noise_pred: Vom Modell vorhergesagtes Rauschen, gleiche Form wie xt :param t: Aktueller Zeitschritt """ # Schätzung des ursprünglichen Bildes x_0 zum Zeitpunkt t x0 = xt - (torch.sqrt(1 - self.a_bar.to(xt.device)[t]) * noise_pred) x0 = x0 / torch.sqrt(self.a_bar.to(xt.device)[t]) x0 = torch.clamp(x0, -1., 1.) # Begrenze Werte von x0 auf den Bereich [-1, 1] # Berechnung des Erwartungswerts (Mittelwert) von x_{t-1} mean = (xt - ((1 - self.a.to(xt.device)[t]) * noise_pred) / torch.sqrt(1 - self.a_bar.to(xt.device)[t])) mean = mean / torch.sqrt(self.a.to(xt.device)[t]) # Falls t = 0, gib nur den Mittelwert zurück (kein Sampling mehr nötig) if t == 0: return mean, x0 else: # Berechnung der Varianz für den Übergang von t zu t-1 variance = (1 - self.a_bar.to(xt.device)[t-1]) / (1 - self.a_bar.to(xt.device)[t]) # Verhältnis der Rauschanteile zwischen zwei aufeinanderfolgenden Zeitpunkten variance = variance * self.b.to(xt.device)[t] # beta ist die Varianz des Forward-Prozesses im Schritt t, also wie viel Rauschen hinzugefügt wurde sigma = variance**0.5 z = torch.randn(xt.shape).to(xt.device) # Zufallsrauschen z ~ N(0, I) # Rückgabe von x_{t-1} (gesampelt) und der Schätzung von x_0 return mean + sigma * z, x0
Ein Testaufruf mit einem zum Zeitpunkt \(t\) verrauschten Bild der Dimension (1, 1, 28, 28) , liefert das Bild zum Zeitpunkt \(t-1\) und das ursprüngliche Bild mit der Dimension (1, 1, 28, 28) zurück.
# Test image = torch.randn(1, 1, 28, 28) noise_pred = torch.randn(1, 1, 28, 28) t = torch.randint(0, 1000, (1,)) # Reverse Process drp = DiffusionReverseProcess() out, x0 = drp.sample_prev_timestep(image, noise_pred, t) print(out.shape) print(x0.shape)
In Diffusionsmodellen ist die Zeitinformation von zentraler Bedeutung, da das Modell zu jedem Zeitschritt wissen muss, in welcher Phase der Rauschvorhersage es sich befindet. Um diese zeitliche Information explizit zu codieren, greift man auf eine Technik zurück, die ursprünglich aus dem Positional Encoding in Transformer-Modellen stammt. Dabei wird der Zeitschritt \( t \) in einen kontinuierlichen Vektorraum eingebettet, sodass das Modell ihn als Eingabe verarbeiten kann. Die Einbettung erfolgt über sinus- und cosinusbasierte Funktionen, die für jede Dimension des Embedding-Vektors eine unterschiedliche Frequenz verwenden. Formal wird die Einbettung \(PE\) von \(0\) bis \(d_{\text{emb}} / 2\) bzw. von \(d_{\text{emb}} / 2\) bis \(d_{\text{emb}}\) wie folgt berechnet:
$$ PE(pos, 0:d_{\text{emb}} / 2) = \sin\left(\frac{pos}{10000^{2i / d_{\text{emb}}}}\right)$$
$$ PE(pos, d_{\text{emb}} / 2:d_{\text{emb}}) = \cos\left(\frac{pos}{10000^{2i / d_{\text{emb}}}}\right) $$
Dabei bezeichnet \( pos \) den Zeitschritt \( t \), \( d_{\text{emb}} \) die Länge des Embedding-Vektors (z. B. 128), und \( i \) den Index innerhalb der jeweiligen Hälfte des Vektors. Der Ausdruck \( 10000^{2i / d_{\text{emb}}} \) definiert die Periodenbasis und sorgt dafür, dass jede Dimension eine andere Frequenz besitzt – so entsteht eine vielfältige, kontinuierliche Repräsentation der Zeit. Diese zeitliche Einbettung wird anschließend in das neuronale Netz eingespeist. Die 'get_time_embedding' Funktion ist folgendermaßen definiert:
def get_time_embedding( time_steps: torch.Tensor, t_emb_dim: int ) -> torch.Tensor: # """ Definiert die Funktion 'get_time_embedding, die zwei Parameter erwartet: # 1. time_steps: einen 1D torch Tensor mit B vielen Zeitpunkten (z. B. [0, 1, 2, ...]) # 2. t_emb_dim: die gewünschte Länge des Embeddings (z. B. 128) # Gibt einen torch Tensor mit B vielen Zeilen und t_emb_dim Spalten zurück # """ assert t_emb_dim%2 == 0, "time embedding must be divisible by 2." #Fehlermeldung wenn t_emb_dim nicht durch 2 ohne Rest teilbar ist factor = 2 * torch.arange(start = 0, end = t_emb_dim//2, dtype=torch.float32, device=time_steps.device) / (t_emb_dim) #Füllt einen t_emb_dim//2 langen torch Tensor mit Werten von 0 bis t_emb_dim//2-1 skaliert auf 2/t_emb_dim factor = 10000**factor #factor = 10000^factor. Der Faktor ist immer < 1. Die Faktoren fungieren als Frequenzen für die Zeitschritte. 'Spätere' time_steps bekommen kleinere Frequenzen t_emb = time_steps[:,None] # B -> (B, 1), erzeugt eine Spalte t_emb = t_emb/factor # (B, 1) -> (B, t_emb_dim//2) #erzeugt eine Matrix mit time_step_0/factor_0 bis time_step_0/factor_(t_emb_dim//2-1)in der ersten Zeile und time_step_(B-1)//factor_0 bis time_step_(B-1)/factor_(t_emb_dim//2-1) in der letzen Zeile t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=1) # (B , t_emb_dim). Sinus und Cosinus wird berechnet. Die Ergebnisse werden entlang der zweiten Achse (dim=1) zusammengefügt d.h. Spalten werden angehängt). Ergebnis: Ein Tensor der Form (B, t_emb_dim) hier (128,128) – die eine Hälfte Sinus, die andere Hälfte Kosinus. return t_emb
Die zentrale Aufgabe des Modells besteht darin, den Rauschanteil eines Bildes zu verschiedenen Zeitpunkten vorherzusagen. Da sowohl Eingabe als auch Ausgabe zweidimensionale Bilder sind, eignet sich eine U-Net Architektur für diese Art von Bild-zu-Bild-Übersetzung. Es besteht aus drei Hauptkomponenten:
1. Downsampling Einheit oder auch Encoder genannt: Hier werden die räumlichen Dimensionen durch aufeinanderfolgende Convolution-Blöcke und Pooling-Operationen reduziert, während die Anzahl der Feature-Kanäle steigt. Dadurch werden die relevanten Merkmale extrahiert.
2. Bottleneck Einheit oder Mittlerer Block: Dies ist die tiefste Ebene des Netzwerks mit der abstraktesten Repräsentation. Es kommen Residual- oder Attention-Mechanismen zum Einsatz, die später noch genauer erklärt werden.
3. Upsampling Einheit auch Decoder genannt: Über Transponierte Convolution oder Interpolation wird die ursprüngliche räumliche Auflösung schrittweise wiederhergestellt. Durch sogenannte 'Skip Connections' zwischen Encoder und Decoder werden Details direkt übertragen und ermöglichen eine präzise Rekonstruktion.
Konkret wird die Eingabeschicht unseres U-Netzwerks also z.B. mit einem verrauschten Bild vom shape (1,1,28,28) zum Zeitpunkt t gefüttert, und die letzte Schicht liefert ein 'vorhergesagtes' Rauschen \(\epsilon_{\Theta}\) vom gleichem shape (1,1,28,28).
Die ASCII Grafik unten zeigt die verwendete UNet Struktur. Abkürzungen stehen für Hilfsmodule die im folgenden genauer erklärt werden.
initial_conv +
Down = 3x(DownC.forward) mit num_DownC_layers =2: (conv1 + te + conv2 + attn + dspl) + (conv1 + te + conv2 + attn + dspl) ------------------------ |----- skip -----| |----- skip -----| | | + (conv1 + te + conv2 + attn + dspl) + (conv1 + te + conv2 + attn + dspl) --------------------- | |----- skip -----| |----- skip -----| | | | | + (conv1 + te + conv2 + attn ) + (conv1 + te + conv2 + attn ) ------------------ | | |----- skip -----| |----- skip -----| | | | | | | | | | Mid = 3x(MidC.forward) mit num_MidC_layers=2: | | | | | | + conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2) | | | |----- skip -----| |----- skip -----| |----- skip -----| | | | | | | + conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2) | | | |----- skip -----| |----- skip -----| |----- skip -----| | | | | | | + conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2) | | | |----- skip -----| |----- skip -----| |----- skip -----| | | | | | | | | | Up = 3x(UpC.forward) mit num_UpC_layers=2: | | | | | | |------------------------------------------------------------------- skip -------- | | + ( conv1 + te + conv2 + attn) + ( conv1 + te + conv2 + attn) | | |----- skip -----| |----- skip -----| | | | | |------------------------------------------------------------------- skip ---------- | + (uspl + conv1 + te + conv2 + attn) + (uspl + conv1 + te + conv2 + attn) | |----- skip -----| |----- skip -----| | | |------------------------------------------------------------------- skip -------------- + (uspl + conv1 + te + conv2 + attn) + (uspl + conv1 + te + conv2 + attn) |----- skip -----| |----- skip -----|
+ final_conv
'Conv(olution)' steht dabei für die Klasse 'NormActConv'. Eine Instanz dieser Klasse führt eine Normierung, anschließend eine Aktivierung und dann eine Faltung aus. Die Normierung erfolgt mit 'GroupNorm'. GroupNorm teilt die Kanäle in mehrere Gruppen (z. B. 8 Gruppen bei num_groups=8) und normalisiert jeden Kanal innerhalb seiner Gruppe. Das bedeutet: Für jede Gruppe wird der Mittelwert \(\mu_G\) und die Standardabweichung \(\sigma_G\) berechnet. Dann wird jeder Wert in der Gruppe wie folgt transformiert: \(x'=(x−\mu_G)/(\sigma_G+\epsilon)\). Dabei ist \(\epsilon\) eine kleine Zahl zur Vermeidung von Division durch Null.
class NormActConv(nn.Module): """ Performs Normalization, Activation and Convolution
""" def __init__(self, in_channels:int, # Anzahl der Eingangskanäle out_channels:int, # Anzahl der Ausgangskanäle num_groups:int = 8, # Anzahl der Gruppen für GroupNorm (default = 8) kernel_size: int = 3, # Kernelgröße für Conv2D (default = 3) norm:bool = True, # GroupNorm wenn gesetzt act:bool = True # SiLU Aktivierung x * sigmoid(x) wenn gesetzt ): super(NormActConv, self).__init__()
# GroupNorm self.g_norm = nn.GroupNorm(num_groups,in_channels) if norm is True else nn.Identity() # Activation self.act = nn.SiLU() if act is True else nn.Identity() # Convolution self.conv = nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size - 1)//2 #Padding wird so gewählt, dass die räumliche Größe erhalten bleibt ) def forward(self, x): x = self.g_norm(x) x = self.act(x) x = self.conv(x) return x
'te' steht für 'TimeEmbeding', und beschreibt wie die Zeit-Einbettung ins Modell integriert wird. Zuerst wird die Eingabe mit der SiLU-Aktivierung geglättet, anschließend durch eine lineare Schicht in die gewünschte Ausgabedimension projiziert. Dazu wird ein linearer layer mit \(x' = x\cdot A + b\) eingefügt, wobei \(x\) die Eingabe mit t_emb_dim Dimensionen, \(x'\) die Ausgabe in n_out Dimensionen, \(A\) wie üblich eine zu trainierende Gewichtsmatrix und \(b\) der Bias ist.
class TimeEmbedding(nn.Module): """ Maps the Time Embedding to the Required output Dimension. """ def __init__(self, n_out:int, # Output Dimension t_emb_dim:int = 128 # Time Embedding Dimension ): super(TimeEmbedding, self).__init__() # Time Embedding Block self.te_block = nn.Sequential( nn.SiLU(), # Aktivierung nn.Linear(t_emb_dim, n_out) # 'embeding' durch Einfügen einer linearen Schicht )
def forward(self, x): return self.te_block(x)
'attn' steht für 'SelfAttention', und erlaubt es dem Modell, nicht-lokale Beziehungen zu lernen, also wie z. B. ein Objekt in der oberen linken Ecke eines Bildes mit einem Objekt unten rechts zusammenhängt. Eine einzelne Aufmersamkeitsfunktion besteht aus einer Abfrage \(Q\), einem Schlüssel \(K\) und einem Wert \(V\). Der Schlüssel \(K\) stellt dabei eine Beziehungsmatrix der Bildausschnitte, oder von Wörtern zueinander dar. Der Wert \(V\) repräsentiert die Bedeutung eines Wortes bzw. den Wertvektor eines Bildausschnittes. Die Selbstaufmerksamkeit wird dann wie folgt berechnet:
$$SelfAttention(Q,K,V) = Softmax(QK^T/\sqrt{d_k})V$$
Das Produkt \(QK^T\) misst, wie ähnlich das aktuelle Wort (Query) zu anderen Wörtern ist. Die Ähnlichkeitswerte werden durch die Softmax-Funktion in Wahrscheinlichkeiten umgewandelt. Damit die Werte nicht zu groß werden, teilt man durch \(\sqrt{𝑑_𝑘}\), wobei \(d_k\) die Dimension der Key-Vektoren ist. Diese Wahrscheinlichkeiten werden auf die Value-Vektoren \(V\) angewendet. So entsteht eine gewichtete Summe über alle 'values'. Bei der multihead Self Attention Operation wird die Eingabe X unter Verwendung unterschiedlicher Gewichtungsmatrizen W in mehrere kleinerdimensionale Unterräume projiziert.
$$Q_i = X\cdot W^Q_i, K_i = X\cdot W^K_i, V_i = X\cdot W^V_i $$
wobei \(i\) den Kopfindex bezeichnet. Jeder Kopf berechnet unabhängig voneinander seine eigene Selbstaufmerksamkeit mit der skalierten Punktproduktformel. Die Ausgänge aller Köpfe werden anschließend verkettet.
class SelfAttentionBlock(nn.Module): """ Perform GroupNorm and Multiheaded Self Attention operation. """
def __init__(self, num_channels:int, num_groups:int = 8, num_heads:int = 4, norm:bool = True ): super(SelfAttentionBlock, self).__init__() # GroupNorm self.g_norm = nn.GroupNorm(num_groups,num_channels) if norm is True else nn.Identity() # Self-Attention self.attn = nn.MultiheadAttention(num_channels,num_heads, batch_first=True) #batch_first=True bedeutet, dass die Eingabe die Form (B, SeqLen, C) hat. def forward(self, x): batch_size, channels, h, w = x.shape x = x.reshape(batch_size, channels, h*w) # """ # Wandelt das 'Bild' oder ein batch von Punkten oder 'Token' in eine Sequenz der Dimension # (B, C, H*W). Jeder räumliche Punkt bei einem Bild (zB. die Kanäle R, G und B' mit # Dimension C=3) wird als ein 'Token' oder Textstück betrachtet. # """ x = self.g_norm(x) # Normalisiert die Kanäle gruppenweise x = x.transpose(1, 2) # neue Form: (B, H*W, C) – notwendig für MultiheadAttention x, _ = self.attn(x, x, x) # Führt die eigentliche MultiHeadSelfAttention durch. Query, Key und Value input sind gleich. Die zurückgegebene Gewichtungsmatrix wird ignoriert. x = x.transpose(1, 2).reshape(batch_size, channels, h, w) # Transponiert zurück zu (B, C, H*W) und formt dann zurück zu (B, C, H, W). return x
'dspl' steht für 'Downsample', und reduziert die räumliche Auflösung der 'Feature-Maps' um den Faktor \(k\). Die Klasse kombiniert 'strided Convolution' und 'Max-Pooling' und passt die Kanalanzahl an.
class Downsample(nn.Module): """ Perform Downsampling by the factor of k across Height and Width. """ def __init__(self, in_channels:int, out_channels:int, k:int = 2, # Downsampling Factor use_conv:bool = True, # If Downsampling using conv-block use_mpool:bool = True # If Downsampling using max-pool ): super(Downsample, self).__init__() self.use_conv = use_conv self.use_mpool = use_mpool # Downsampling using Convolution self.cv = nn.Sequential( nn.Conv2d(in_channels, in_channels, kernel_size=1), #Eine 1×1-Conv, der die Kanäle nicht verändert, aber als Feature-Transformation dient. nn.Conv2d(in_channels, out_channels//2 if use_mpool else out_channels, kernel_size=4, stride=k, padding=1) #Eine 4×4-Conv mit Stride k (Filter bewegt sich in Schritten von k Pixeln), der die räumliche Größe reduziert. Wenn use_mpool=True, wird die Anzahl der Ausgabekanäle halbiert (out_channels//2), da der max-pool Zweig auch Features liefert. ) if use_conv else nn.Identity() # Downsampling using Maxpool self.mpool = nn.Sequential( nn.MaxPool2d(k, k), #Downsampling durch Pooling mit Fenstergröße und Stride k. nn.Conv2d(in_channels, out_channels//2 if use_conv else out_channels, kernel_size=1, stride=1, padding=0) #1×1-Conv zur Kanaltransformation nach dem Pooling. ) if use_mpool else nn.Identity() def forward(self, x): if not self.use_conv: # Nur MaxPool-Zweig wird verwendet. return self.mpool(x) if not self.use_mpool: # Nur Convolution-Zweig wird verwendet. return self.cv(x) return torch.cat([self.cv(x), self.mpool(x)], dim=1) # cv und mpool werden entlang der Kanalachse (dim=1) zusammengefügt.
'uspl' steht für 'UpSample', und vergrößert hingegen Bildmerkmale um den Faktor \(k\). Das erfolgt per transponierter Faltung, bilinearer Interpolation oder beidem.
class Upsample(nn.Module): """ Perform Upsampling by the factor of k across Height and Width """ def __init__(self, in_channels:int, out_channels:int, k:int = 2, # Upsampling Faktor use_conv:bool = True, # Upsampling durch 'conv-block' (transposed convolution) wenn gesetzt use_upsample:bool = True # Upsampling durch 'upsample' (bilinear upsampling) wenn gesetzt ): super(Upsample, self).__init__() self.use_conv = use_conv self.use_upsample = use_upsample # Upsampling using conv self.cv = nn.Sequential( nn.ConvTranspose2d(in_channels,out_channels//2 if use_upsample else out_channels, kernel_size=4, stride=k, padding=1), # Transposed Convolution fungiert als 'de convolution'. Stride=k sorgt dafür, dass die räumliche Größe um den Faktor k vergrößert wird. nn.Conv2d(out_channels//2 if use_upsample else out_channels, out_channels//2 if use_upsample else out_channels, kernel_size = 1, stride=1, padding=0) # Ein 1×1-Conv zur Kanalmischung nach dem Upsampling. ) if use_conv else nn.Identity() # Upsamling using nn.Upsample self.up = nn.Sequential( nn.Upsample(scale_factor=k, mode = 'bilinear', align_corners=False), #Vergrößert die Höhe und Breite um den Faktor k durch Interpolation. mode='bilinear' ist eine gängige Methode für glatte Ergebnisse. nn.Conv2d(in_channels,out_channels//2 if use_conv else out_channels, kernel_size=1, stride=1, padding=0) # Auch hier ein 1×1-Conv zur Anpassung der Kanalanzahl. ) if use_upsample else nn.Identity() def forward(self, x): if not self.use_conv: return self.up(x) # Nur bilineares Upsampling wird verwendet. if not self.use_upsample: return self.cv(x) # Nur Transposed Convolution wird verwendet. return torch.cat([self.cv(x), self.up(x)], dim=1) # Beide Methoden werden kombiniert, und die Ergebnisse entlang der Kanalachse (dim=1) zusammengefügt. Dadurch entsteht ein Feature-Map mit out_channels Kanälen.
Damit sind alle Hilfsklassen definiert.
Um das U-Netz modular zu gestalten werden nun die Klassen 'DownC', 'MidC' und 'UpC' definiert. `DownC` verarbeitet die Eingabeschicht durch wiederholte Convolution, Zeit-Embedding, Self-Attention und anschließendes Downsampling. Wie in einem residual network (ResNet), werden auch sogenannte 'skip connections' verwendet. Ein allgemeiner residual block in einem ResNet sieht wie folgt aus:
Über eine 'skip connection' wird der Eingang eines Teilnetzes mit dessen Ausgang verbunden. Dadurch wird das bei tiefen Transformerstrukturen auftretende Problem des verschwindenden Gradienten abgeschwächt. In der Tat können sehr tiefe Transformerstrukturen nicht ohne 'skip connections' trainiert werden
class DownC(nn.Module): """ Perform Down-convolution on the input using following general approach. 1. Convolution 2. TimeEmbedding 3. Convolution 4. Self-Attention 5. Downsampling """ def __init__(self, in_channels:int, out_channels:int, t_emb_dim:int = 128, # Time Embedding Dimension num_layers:int=2, down_sample:bool = True # True for Downsampling ): super(DownC, self).__init__() self.num_layers = num_layers # conv1: Erste Convolution vor Zeit-Embedding self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers)]) # conv2: Zweite Convolution nach Addition des Embeddings self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers)]) # Zeit-Embedding self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers)]) # Self-Attention self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)]) # optionales Downsampling self.down_block =Downsample(out_channels, out_channels) if down_sample else nn.Identity() # 1×1-Conv für die Skip-Connection self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers)]) def forward(self, x, t_emb): out = x for i in range(self.num_layers): resnet_input = out # Schritt 1: Speichern des Inputs für Skip-Connection out = self.conv1[i](out) #Schritt 2: Erste Convolution out = out + self.te_block[i](t_emb)[:, :, None, None] #Schritt 3: Zeit-Embedding wird auf die räumliche Form „gebroadcasted“ (mit None erweitert) und hinzugefügt. out = self.conv2[i](out) # Schritt 4: Zweite Convolution out = out + self.res_block[i](resnet_input) # Schritt 5: Skip-Connection out_attn = self.attn_block[i](out) # Schritt 6: Self-Attention out = out + out_attn # Attention-Ergebnis wird hinzugefügt return out
Die Klasse `MidC` verfeinert Feature-Maps aus dem Downsampling-Pfad durch einen initialen ResNet-Block mit Zeit-Embedding, gefolgt von mehreren Self-Attention- und ResNet-Blöcken, die ebenfalls zeitliche Informationen integrieren.
class MidC(nn.Module): """ Refine the features obtained from the DownC block. It refines the features using following operations: 1. Resnet Block with Time Embedding 2. A Series of Self-Attention + Resnet Block with Time-Embedding """ def __init__(self, in_channels:int, out_channels:int, t_emb_dim:int = 128, num_layers:int = 2 ): super(MidC, self).__init__() self.num_layers = num_layers self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers + 1)]) self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers + 1)]) self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers + 1)]) self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)]) self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers + 1)]) def forward(self, x, t_emb): out = x # Erster-Resnet Block resnet_input = out out = self.conv1[0](out) out = out + self.te_block[0](t_emb)[:, :, None, None] out = self.conv2[0](out) out = out + self.res_block[0](resnet_input) # Abfolge von Self-Attention + Resnet Blocks for i in range(self.num_layers): # Self Attention out_attn = self.attn_block[i](out) out = out + out_attn # Resnet Block resnet_input = out out = self.conv1[i+1](out) out = out + self.te_block[i+1](t_emb)[:, :, None, None] out = self.conv2[i+1](out) out = out + self.res_block[i+1](resnet_input) return out
Die Klasse `UpC` führt eine mehrschichtige Up-Convolution durch, kombiniert Upsampling, zeitliche Einbettung, Residualverbindungen und Selbstaufmerksamkeit, um Feature-Maps effizient zu verarbeiten und mit Encoder-Ausgaben zu verknüpfen.
class UpC(nn.Module): """ Perform Up-convolution on the input using following approach. 1. Upsampling 2. Convolution 3. TimeEmbedding 4. Convolution 6. Self-Attention """ def __init__(self, in_channels:int, out_channels:int, t_emb_dim:int = 128, # Time Embedding Dimension num_layers:int = 2, up_sample:bool = True # Upsampling wenn Wahr ): super(UpC, self).__init__() self.num_layers = num_layers self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers)]) self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers)]) self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers)]) self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)]) self.up_block =Upsample(in_channels, in_channels//2) if up_sample else nn.Identity() self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers)]) def forward(self, x, down_out, t_emb): # Upsampling x = self.up_block(x) x = torch.cat([x, down_out], dim=1) out = x for i in range(self.num_layers): resnet_input = out # Resnet Block out = self.conv1[i](out) out = out + self.te_block[i](t_emb)[:, :, None, None] out = self.conv2[i](out) out = out + self.res_block[i](resnet_input) # Self Attention out_attn = self.attn_block[i](out) out = out + out_attn return out
Die Klasse `Unet` implementiert schließlich die U-Net-Architektur. Sie kombiniert Downsampling-, Mittelkern- und Upsampling-Blöcke mit Zeit-Embedding, Residualverbindungen und Selbstaufmerksamkeit zur Verarbeitung und Rekonstruktion von Bildmerkmalen. Im folgenden Code wurde die Zahl der 'DownC', 'MidC' und 'UpC' layer von 2 auf 1 reduziert um Rechenzeit zu sparen.
class Unet(nn.Module): """ U-net architecture which is used to predict noise. U-net consists of Series of DownC blocks followed by MidC followed by UpC. """ def __init__(self, im_channels: int = 1, # Grayscale down_ch: list = [32, 64, 128], #down_ch: list = [32, 64, 128, 256], mid_ch: list = [128, 128], #mid_ch: list = [256, 256, 128], up_ch: list[int] = [128, 64, 16], #[256, 128, 64, 16], down_sample: list[bool] = [True, True], #down_sample: list[bool] = [True, True, False], t_emb_dim: int = 128, num_downc_layers:int = 1, #2 num_midc_layers:int = 1, #2 num_upc_layers:int = 1 #2 ): super(Unet, self).__init__() self.im_channels = im_channels self.down_ch = down_ch self.mid_ch = mid_ch self.up_ch = up_ch self.t_emb_dim = t_emb_dim self.down_sample = down_sample self.num_downc_layers = num_downc_layers self.num_midc_layers = num_midc_layers self.num_upc_layers = num_upc_layers self.up_sample = list(reversed(self.down_sample)) # [False, True, True] # Initial Convolution (1=>32) self.cv1 = nn.Conv2d(self.im_channels, self.down_ch[0], kernel_size=3, padding=1) #Faltung die die Zahl der Bildkanäle auf 32 vergrößert # Initial Time Embedding Projection self.t_proj = nn.Sequential( nn.Linear(self.t_emb_dim, self.t_emb_dim), #linear function dim 128 => dim 128 nn.SiLU(), #activation nn.Linear(self.t_emb_dim, self.t_emb_dim) #linear function ) # DownC Block self.downs = nn.ModuleList([ DownC( self.down_ch[i], self.down_ch[i+1], # Zahl der Bildkanäle wird schrittweise auf 64 => 128 => 256 vergrößert self.t_emb_dim, #128 self.num_downc_layers, #2 self.down_sample[i] # wenn down_sampling => false => false ) for i in range(len(self.down_ch) - 1) ]) # MidC Block self.mids = nn.ModuleList([ MidC( self.mid_ch[i], self.mid_ch[i+1], self.t_emb_dim, self.num_midc_layers ) for i in range(len(self.mid_ch) - 1) ]) # UpC Block self.ups = nn.ModuleList([ UpC( self.up_ch[i], self.up_ch[i+1], self.t_emb_dim, self.num_upc_layers, self.up_sample[i] ) for i in range(len(self.up_ch) - 1) ]) # Final Convolution self.cv2 = nn.Sequential( nn.GroupNorm(8, self.up_ch[-1]), nn.Conv2d(self.up_ch[-1], self.im_channels, kernel_size=3, padding=1) ) def forward(self, x, t): # wird beim Aufruf von Unet automatisch ausgeführt out = self.cv1(x) # vergrößert die Zahl der Bildkanäle auf 32
# Time Projection t_emb = get_time_embedding(t, self.t_emb_dim) #(128,128) Tensor bestehend aus sin und cos Werten von t bei unterschiedlichen Frequenzen t_emb = self.t_proj(t_emb) # ein batch von 128 time_embedings der Länge 128 wird (parallel) durch einen Linear Layer (128 -> 128) geschickt, aktiviert und erneut durch einen LL (128->128) geschickt # DownC outputs down_outs = [] for down in self.downs: #2x down_outs.append(out) out = down(out, t_emb) # MidC outputs for mid in self.mids: out = mid(out, t_emb) # UpC Blocks for up in self.ups: down_out = down_outs.pop() out = up(out, down_out, t_emb) # Final Conv out = self.cv2(out) return out
Damit ist das Netzwerk endlich definiert und wir können mit dem Training beginnen. Ich möchte für das Training den MNIST Datensatz aus Graustufenbildern der Größe 28 x 28 verwenden.
Zunächst definieren wir, wie wir unseren Trainingsdatensatz einlesen wollen.
import pandas as pd import numpy as np import torchvision from torch.utils.data.dataset import Dataset from PIL import Image class CustomMnistDataset(Dataset): """ Reads the MNIST data from csv file given file path. """ def __init__(self, csv_path, num_datapoints = None): super(CustomMnistDataset, self).__init__() self.df = pd.read_csv(csv_path) if num_datapoints is not None: self.df = self.df.iloc[0:num_datapoints] def __len__(self): return len(self.df) def __getitem__(self, index): # Read img = self.df.iloc[index].filter(regex='pixel').values img = np.reshape(img, (28, 28)).astype(np.uint8) # Convert to Tensor img_tensor = torchvision.transforms.ToTensor()(img) # [0, 1] img_tensor = 2*img_tensor - 1 # [-1, 1] return img_tensor
Dann werden die Metadaten gesetzt,
class CONFIG: model_path = 'ddpm_unet.pth' train_csv_path = '/kaggle/input/train-mod30k/train_mod30k.csv' test_csv_path = '/kaggle/input/digit-recognizer/test.csv' generated_csv_path = 'mnist_generated_data.csv' num_epochs = 5 #50 lr = 1e-4 num_timesteps = 1000 batch_size = 128 img_size = 28 #28 in_channels = 1 num_img_to_generate = 5 #256
und anschließend der Trainingsloop definiert.
from torch.utils.data import DataLoader from tqdm import tqdm def train(cfg): # Dataset and Dataloader mnist_ds = CustomMnistDataset(cfg.train_csv_path) mnist_dl = DataLoader(mnist_ds, cfg.batch_size, shuffle=True) # Device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # CUDA (Compute Unified Device Architecture) ist eine von Nvidia entwickelte Programmierschnittstelle (API), mit der Programmteile durch den Grafikprozessor (GPU) abgearbeitet werden können. print(f'Device: {device}\n') # Initiate Model model = Unet().to(device) # Initialize Optimizer and Loss Function optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) # Adam-Optimieren passst die Lernrate automatisch an. model.parameters() enthält alle lernbaren Gewichte criterion = torch.nn.MSELoss() #berechnet den Mean SquaredError also den mittleren Quadratischen Fehler zwischen zwei Tensoren (siehe Zeile 55) # Diffusion Forward Process to add noise dfp = DiffusionForwardProcess() # Best Loss best_eval_loss = float('inf') # Train for epoch in range(cfg.num_epochs): # For Loss Tracking losses = [] # Set model to train mode model.train() # Loop over dataloader for imgs in tqdm(mnist_dl): imgs = imgs.to(device) #(128,1,28,28), batch_size in CONFIG ist 128, die bilder sind Graustufenbilder der Größe 28 x 28 # Generate noise and timestamps noise = torch.randn_like(imgs).to(device) #für jeds Bild wird ein anderes epsilon (Rauschen) verwendet t = torch.randint(0, cfg.num_timesteps, (imgs.shape[0],)).to(device) #batch_size=imgs.shape[0]=128. Jedes Bild wird bis zu einem anderen Zeitpunkt t verrauscht. # Add noise to the images using Forward Process noisy_imgs = dfp.add_noise(imgs, noise, t) #erzeuge 128 verrauschte Bilder zum unterschiedlichen Zeitpunkten t (Es werden nicht alle Zeitpunkte für jedes Bild zum Training herangezogen) # Avoid Gradient Accumulation optimizer.zero_grad() #alte Gradienten löschen # Predict noise using U-net Model noise_pred = model(noisy_imgs, t) # Calculate Loss loss = criterion(noise_pred, noise) #berechne den mittleren quadratischen Fehler (MSE) zwischen dem Rauschen und dem vorhergesagten Rauschen losses.append(loss.item()) # Backprop + Update model params loss.backward() # Gradienten berechnen optimizer.step() # Modellparameter des U-Netzes anpassen # Mean Loss mean_epoch_loss = np.mean(losses) #mittlerer Loss über alle Bilder # Display print('Epoch:{} | Loss : {:.4f}'.format( epoch + 1, mean_epoch_loss, )) # Save based on train-loss if mean_epoch_loss < best_eval_loss: best_eval_loss = mean_epoch_loss torch.save(model, cfg.model_path) print(f'Done training.....')
Dann startet das Training.
# Config cfg = CONFIG() # TRAIN train(cfg)
Ist das Training beendet, und das Model abgespeichert, können neue Bilder generiert werden.
def generate(cfg): """ Given Pretrained DDPM U-net model, Generate Real-life Images from noise by going backward step by step. i.e., Mapping of Random Noise to Real-life images. """ # Device setzen device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Diffusion Reverse Process initiieren drp = DiffusionReverseProcess() # Model in Evaluierungs Modus versetzten model = torch.load(cfg.model_path, weights_only=False).to(device) #`weights_only=True` model.eval() # Ein verrauschtes Bild mit N(0, 1) generieren xt = torch.randn(1, cfg.in_channels, cfg.img_size, cfg.img_size).to(device) # Schritt für Schritt entrauschen with torch.no_grad(): # Keine Berechnung von Gradienten bei der Bildgenerierung (Inferenz)) for t in reversed(range(cfg.num_timesteps)): noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device)) # Wandelt t in einen Tensor um, falls es z. B. ein einfacher Integer ist.unsqueeze(0): Fügt eine zusätzliche Dimension hinzu, sodass t die Form (1,) hat – passend zur Batch-Größe. xt, x0 = drp.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device)) # Gibt x_(t-1) und x0 zurück # Bild skalieren xt = torch.clamp(xt, -1., 1.).detach().cpu() #Begrenzung der Werte im Tensor xt auf den Bereich von −1 bis +1.detach() trennt den Tensor vom Berechnungsgraphen. Das bedeutet: Es werden keine Gradienten mehr berechnet, und der Tensor ist nun „autonom“.cpu(): Verschiebt den Tensor vom GPU-Speicher zurück in den CPU-Speicher, damit er z. B. mit NumPy weiterbearbeitet werden kann. xt = (xt + 1) / 2 return xt # gibt das 'entrauschte' Bild zurück
Auch die erzeugten Bilder werden abgespeichert, und anschließend dargestellt. Aus Zeitgründen habe ich in der Regel nur 5 Bilder generieren lassen um die Qualität der Ergebnisse bewerten zu können.
# Model laden und konfigurieren cfg = CONFIG() # Bild generieren generated_imgs = [] for i in tqdm(range(cfg.num_img_to_generate)): xt = generate(cfg) xt = 255 * xt[0][0].numpy() generated_imgs.append(xt.astype(np.uint8).flatten()) # Erzeugte Daten speichern generated_df = pd.DataFrame(generated_imgs, columns=[f'pixel{i}' for i in range(784)]) generated_df.to_csv(cfg.generated_csv_path, index=False) # Visualisieren from matplotlib import pyplot as plt fig, axes = plt.subplots(1, 5, figsize=(5, 5)) # Jedes Bild in einem Subplot darstellen for i, ax in enumerate(axes.flat): ax.imshow(np.reshape(generated_imgs[i], (28, 28)), cmap='gray') plt.tight_layout() # Platz zwischen Subplots anpassen plt.show()
Nach nur 5 Trainingsdurchläufen mit einer Laufzeit von je 2 Minuten und 25 Sekunden und unter Verwendung von 42 Tausend Bildern, habe ich folgende Ergebnisse erhalten.
Dies stellt bereits eine deutliche Verbesserung gegenüber dem 'ersten Versuch ein Denoising Probalistic Model zu programmieren' dar. Bei Verwendung von 30 Tausend Bildern reduziert sich die Laufzeit pro 'Epoche' auf 1 Minute und 40 Sekunden. Die Qualität nimmt ab, ist aber immer noch akzeptabel.
Es sei darauf hingewiesen, dass sich diese schnellen Laufzeiten nur unter Verwendung einer GPU (hier einer NVIDIA Tesla P100 ) erzielen lassen. Das gesamte Programm lässt sich unterdiesem link als Kaggle Notebook ausführen. Im Notebook finden sich auch links auf weiterführende Informationen.
Letztlich habe ich das Netzwerk noch mit 40 Durchläufen trainiert. Der Verlauf des 'Losses' ist unten gezeigt.
Das Ergebnis

 
Keine Kommentare:
Kommentar veröffentlichen