Mittwoch, 29. Oktober 2025

Zweiter Versuch ein Denoising Probalistic Model zu programmieren

Im zweiten Versuch ein Denoising Diffusion Probabilistic Model zu programmieren folge ich einem Artikel von Jonathan Ho, Ajay Jain und Pieter Abbeel zu diesem Thema. Wie bereits im ersten Versuch erläutert, wird im Forward-Prozess schrittweise Rauschen zu einem Eingabebild \( x_0 \) hinzugefügt. Das Rauschen hat eine zeitabhängige Normalverteilung, wobei ein Parameter \( \beta_t \) die Intensität des Rauschens steuert. 


Die Verteilung der 'Farbwerte' von \(x_t\) hängt nur von \(x_{t-1}\) ab. Es gilt:

$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \cdot x_{t-1}, \beta_t \cdot I) $$

Sie besitzt einen Mittelwert von \(\sqrt{1 - \beta_t} \cdot x_{t-1}\) und  eine Varianz von \(\beta_t \cdot I\) . Dabei ist \(I\) die Einheitsmatrix, weil die Störung gleichmäßig in jede Richtung erfolgen soll. Für alle Zeitschritte \( t \) gilt \( \beta_t < 1 \) . Für die Konfiguration des Forward-Prozesses habe ich eine Anzahl von \( T\) = 1000 Zeitschritten gewählt, wobei die Rauschparameter linear von \( \beta_1 = 10^{-4} \) bis \( \beta_T\) = 0,02 ansteigen. Die lineare Zeitplanung erfolgt nach der Formel

$$ \beta_t = \beta_1 + \frac{(t - 1)}{T} \cdot \beta_T $$

Für ein gegebenes Bild \( x_0 \) und die festgelegten \( \beta_t \) lassen sich alle Zwischenzustände \( x_t \) direkt berechnen, ohne dass man rekursiv über alle vorherigen Schritte gehen muss. Die entsprechende Verteilung lautet hierfür

$$ q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} \cdot x_0, (1 - \bar{\alpha}_t) \cdot I) $$

wobei

$$ \alpha_t = 1 - \beta_t $$

und 

$$ \bar{\alpha}_t =\prod_{s=1}^{t} \alpha_s $$

gilt. Für \( x_t \) selbst gilt in Abhängigkeit von \( x_0 \): 

$$ x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon $$

\( \epsilon \sim \mathcal{N}(0, I) \) ist ein Zufallsrauschen. Der gesamte Vorwärts-Prozess enthält also  keine trainierbaren Parameter.  Er dient lediglich dazu, das ursprüngliche Bild systematisch in Rauschen zu überführen und bildet die Grundlage für das spätere Lernen im Reverse-Prozess. Im Programm wird der Vorwärtsprozess in der Klasse 'DiffusionForwardProcess' wie folgt definiert:

# Importe
import torch
import torch.nn as nn

class DiffusionForwardProcess:
    
    """
    Forward process as described in the paper “Denoising Diffusion Probabilistic Models".
    """
    
    def __init__(self, 
                 num_time_steps = 1000, 
                 beta_start = 1e-4, 
                 beta_end = 0.02
                ):
        
        # Vorberechnung von beta, alpha und alpha_bar für alle Zeitschritte t
        self.betas = torch.linspace(beta_start, beta_end, num_time_steps)  # 1D-Tensor mit gleichmäßig verteilten Werten zwischen beta_start und beta_end
        self.alphas = 1 - self.betas  # alpha = 1 - beta
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # Kumulatives Produkt: z. B. [α₁, α₁·α₂, α₁·α₂·α₃, ...]
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)  # Quadratwurzel von alpha_bar
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars)  # Quadratwurzel von (1 - alpha_bar)
        
    def add_noise(self, original, noise, t):  # Hinweis: alle Argumente müssen Tensoren sein
        
        """
        Fügt einem Batch von Originalbildern zum Zeitpunkt t Rauschen hinzu.
        :parameter original: Eingabebild als Tensor
        :parameter noise: Zufallsrauschen aus der Normalverteilung N(0, 1)
        :parameter t: Zeitschritt des Forward-Prozesses, shape: (B,)
        Hinweis: Der Zeitschritt t kann für jedes Bild im Batch unterschiedlich sein.
        """
        
        sqrt_alpha_bar_t = self.sqrt_alpha_bars.to(original.device)[t]  # Auf GPU übertragen, falls das Originalbild auf der GPU liegt
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_bars.to(original.device)[t]
        
        # Broadcasting zur Multiplikation mit dem Originalbild
        sqrt_alpha_bar_t = sqrt_alpha_bar_t[:, None, None, None]  # Zusätzliche Dimensionen für (B, C, H, W)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t[:, None, None, None]

        # Rückgabe des verrauschten Bildes
        return (sqrt_alpha_bar_t * original) + (sqrt_one_minus_alpha_bar_t * noise)  
        # Jeder Pixel des Originalbildes wird mit sqrt_alpha_bar_t multipliziert und das Rauschen addiert.
        # Falls t z. B. vier Zeitschritte enthält, werden vier Bilder mit entsprechendem Rauschen erzeugt.

Untenstehender Testaufruf gibt 4 unterschiedliche Bilder zu 4 unterschiedlichen Zeitpunkten zurück.

# Test
original = torch.randn(4,1, 28, 28) # batch aus 4 Bildern mit Kanalzahl 1, Höhe 28 und Breite 28
noise = torch.randn(1, 28, 28)
t_steps = torch.randint(0, 1000, (4,)) #1D Tensor aus 4 Zufallszahlen zwischen 0 und 1000

# Forward Process
dfp = DiffusionForwardProcess() #Instanz der Klasse erzeugen
out = dfp.add_noise(original, noise, t_steps) #gibt 4 unterschiedliche Bildern zu 4 unterschiedlichen Zeitpunkten zurück 
print(out.shape)

Die Dimensionen des verrauschten Bildbatches sind mit (4,1,28,28) identisch zum ursprünglichen batch.

Der Reverse-Prozess eines DDPM lässt sich als die Umkehrung des Forward-Prozesses verstehen. Ziel ist es, ein vollständig verrauschtes Bild \( x_T \sim \mathcal{N}(0, I) \) über \( T \) Zeitschritte hinweg schrittweise zu entrauschen, bis das ursprüngliche Bild rekonstruiert ist. Während der Forward-Prozess deterministisch ist und keine trainierbaren Parameter enthält, wird im Reverse-Prozess typischerweise ein U-Net eingesetzt, um das Rauschen \( \epsilon_\theta \) zu jedem Zeitpunkt zu schätzen. Die Rekonstruktion von \( x_{t-1} \) aus \( x_t \) erfolgt über die Gleichung

$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_{\theta} \right) + \sigma_t z $$

wobei \( z \sim \mathcal{N}(0, I) \) ein zusätzlicher Zufallsvektor ist, der die stochastische Natur der Diffusion bewahrt. Zusätzlich lässt sich zu jedem Zeitpunkt eine Schätzung des ursprünglichen Bildes \( x_0 \) berechnen, und zwar durch

$$ x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_{\theta} \right) $$

Diese Formeln bilden die Grundlage des Reverse-Prozesses, bei dem das Modell lernt, aus reinem Rauschen ein realistisches Bild zu generieren. Die Qualität der Rekonstruktion hängt dabei entscheidend von der Güte der Rauschvorhersage \( \epsilon_{\theta} \) ab, die wir im nächsten Schritt durch ein trainierbares U-Net-Modell implementieren werden. 

Der Rückwärtsprozess wird in der unterstehenden Klasse 'DiffusionReverseProcess' definiert.

class DiffusionReverseProcess:
    
    """
    Reverse process as described in the paper “Denoising Diffusion Probabilistic Models".
    """
    
    def __init__(self, 
                 num_time_steps = 1000, 
                 beta_start = 1e-4, 
                 beta_end = 0.02
                ):
        
        # Vorberechnung von beta, alpha und alpha_bar für alle Zeitschritte t
        self.b = torch.linspace(beta_start, beta_end, num_time_steps)  # b → beta
        self.a = 1 - self.b  # a → alpha
        self.a_bar = torch.cumprod(self.a, dim=0)  # a_bar = alpha_bar (kumulatives Produkt)
        
    def sample_prev_timestep(self, xt, noise_pred, t):
        
        """
        Berechnet x_{t-1} aus x_t und dem vom Modell vorhergesagten Rauschen.
        :param xt: Bildtensor zum Zeitpunkt t, Form → B x C x H x W
        :param noise_pred: Vom Modell vorhergesagtes Rauschen, gleiche Form wie xt
        :param t: Aktueller Zeitschritt
        """
        
        # Schätzung des ursprünglichen Bildes x_0 zum Zeitpunkt t
        x0 = xt - (torch.sqrt(1 - self.a_bar.to(xt.device)[t]) * noise_pred)
        x0 = x0 / torch.sqrt(self.a_bar.to(xt.device)[t])
        x0 = torch.clamp(x0, -1., 1.)  # Begrenze Werte von x0 auf den Bereich [-1, 1]
        
        # Berechnung des Erwartungswerts (Mittelwert) von x_{t-1}
        mean = (xt - ((1 - self.a.to(xt.device)[t]) * noise_pred) / torch.sqrt(1 - self.a_bar.to(xt.device)[t]))
        mean = mean / torch.sqrt(self.a.to(xt.device)[t])
        
        # Falls t = 0, gib nur den Mittelwert zurück (kein Sampling mehr nötig)
        if t == 0:
            return mean, x0
        
        else:
            # Berechnung der Varianz für den Übergang von t zu t-1
            variance = (1 - self.a_bar.to(xt.device)[t-1]) / (1 - self.a_bar.to(xt.device)[t])  # Verhältnis der Rauschanteile zwischen zwei aufeinanderfolgenden Zeitpunkten
            variance = variance * self.b.to(xt.device)[t]  # beta ist die Varianz des Forward-Prozesses im Schritt t, also wie viel Rauschen hinzugefügt wurde
            sigma = variance**0.5
            z = torch.randn(xt.shape).to(xt.device)  # Zufallsrauschen z ~ N(0, I)
            
            # Rückgabe von x_{t-1} (gesampelt) und der Schätzung von x_0
            return mean + sigma * z, x0

Ein Testaufruf mit einem zum Zeitpunkt \(t\) verrauschten Bild der Dimension (1, 1, 28, 28) , liefert das Bild zum Zeitpunkt \(t-1\) und das ursprüngliche Bild mit der Dimension (1, 1, 28, 28) zurück.

# Test
image = torch.randn(1, 1, 28, 28)
noise_pred = torch.randn(1, 1, 28, 28)
t = torch.randint(0, 1000, (1,)) 

# Reverse Process
drp = DiffusionReverseProcess()
out, x0 = drp.sample_prev_timestep(image, noise_pred, t)
print(out.shape)
print(x0.shape)

In Diffusionsmodellen ist die Zeitinformation von zentraler Bedeutung, da das Modell zu jedem Zeitschritt wissen muss, in welcher Phase der Rauschvorhersage es sich befindet. Um diese zeitliche Information explizit zu codieren, greift man auf eine Technik zurück, die ursprünglich aus dem Positional Encoding in Transformer-Modellen stammt. Dabei wird der Zeitschritt \( t \) in einen kontinuierlichen Vektorraum eingebettet, sodass das Modell ihn als Eingabe verarbeiten kann. Die Einbettung erfolgt über sinus- und cosinusbasierte Funktionen, die für jede Dimension des Embedding-Vektors eine unterschiedliche Frequenz verwenden. Formal wird die Einbettung \(PE\) von \(0\) bis \(d_{\text{emb}} / 2\) bzw. von \(d_{\text{emb}} / 2\) bis \(d_{\text{emb}}\) wie folgt berechnet:

$$ PE(pos, 0:d_{\text{emb}} / 2) = \sin\left(\frac{pos}{10000^{2i / d_{\text{emb}}}}\right)$$

$$ PE(pos, d_{\text{emb}} / 2:d_{\text{emb}}) = \cos\left(\frac{pos}{10000^{2i / d_{\text{emb}}}}\right) $$

Dabei bezeichnet \( pos \) den Zeitschritt \( t \), \( d_{\text{emb}} \) die Länge des Embedding-Vektors (z. B. 128), und \( i \) den Index innerhalb der jeweiligen Hälfte des Vektors. Der Ausdruck \( 10000^{2i / d_{\text{emb}}} \) definiert die Periodenbasis und sorgt dafür, dass jede Dimension eine andere Frequenz besitzt – so entsteht eine vielfältige, kontinuierliche Repräsentation der Zeit. Diese zeitliche Einbettung wird anschließend in das neuronale Netz eingespeist. Die 'get_time_embedding' Funktion ist folgendermaßen definiert:

def get_time_embedding(
    time_steps: torch.Tensor,
    t_emb_dim: int
) -> torch.Tensor:
    
    # """ Definiert die Funktion 'get_time_embedding, die zwei Parameter erwartet: 
    # 1. time_steps: einen 1D torch Tensor mit B vielen Zeitpunkten (z. B. [0, 1, 2, ...])
    # 2. t_emb_dim: die gewünschte Länge des Embeddings (z. B. 128)
    # Gibt einen torch Tensor mit B vielen Zeilen und t_emb_dim Spalten zurück
    # """
    
    assert t_emb_dim%2 == 0, "time embedding must be divisible by 2." #Fehlermeldung wenn t_emb_dim nicht durch 2 ohne Rest teilbar ist
    
    factor = 2 * torch.arange(start = 0, end = t_emb_dim//2, dtype=torch.float32, device=time_steps.device) / (t_emb_dim) #Füllt einen t_emb_dim//2 langen torch Tensor mit Werten von 0 bis t_emb_dim//2-1 skaliert auf 2/t_emb_dim
 
    factor = 10000**factor #factor = 10000^factor. Der Faktor ist immer < 1. Die Faktoren fungieren als Frequenzen für die Zeitschritte. 'Spätere' time_steps bekommen kleinere Frequenzen

    t_emb = time_steps[:,None] # B -> (B, 1), erzeugt eine Spalte
    
    t_emb = t_emb/factor # (B, 1) -> (B, t_emb_dim//2) #erzeugt eine Matrix mit time_step_0/factor_0 bis time_step_0/factor_(t_emb_dim//2-1)in der ersten Zeile und time_step_(B-1)//factor_0 bis time_step_(B-1)/factor_(t_emb_dim//2-1) in der letzen Zeile

    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=1) # (B , t_emb_dim). Sinus und Cosinus wird berechnet. Die Ergebnisse werden entlang der zweiten Achse (dim=1) zusammengefügt d.h. Spalten werden angehängt). Ergebnis: Ein Tensor der Form (B, t_emb_dim) hier (128,128) – die eine Hälfte Sinus, die andere Hälfte Kosinus.

    return t_emb

Die zentrale Aufgabe des Modells besteht darin, den Rauschanteil eines Bildes zu verschiedenen Zeitpunkten vorherzusagen. Da sowohl Eingabe als auch Ausgabe zweidimensionale Bilder sind, eignet sich eine U-Net Architektur für diese Art von Bild-zu-Bild-Übersetzung. Es besteht aus drei Hauptkomponenten:

1. Downsampling Einheit oder auch Encoder genannt: Hier werden die räumlichen Dimensionen durch aufeinanderfolgende Convolution-Blöcke und Pooling-Operationen reduziert, während die Anzahl der Feature-Kanäle steigt. Dadurch werden die relevanten Merkmale extrahiert.

2. Bottleneck Einheit oder Mittlerer Block: Dies ist die tiefste Ebene des Netzwerks mit der abstraktesten Repräsentation. Es kommen Residual- oder Attention-Mechanismen zum Einsatz, die später noch genauer erklärt werden.

3. Upsampling Einheit auch Decoder genannt: Über Transponierte Convolution oder Interpolation wird die ursprüngliche räumliche Auflösung schrittweise wiederhergestellt. Durch sogenannte 'Skip Connections' zwischen Encoder und Decoder werden Details direkt übertragen und ermöglichen eine präzise Rekonstruktion.

Konkret wird die Eingabeschicht unseres U-Netzwerks also z.B. mit einem verrauschten Bild vom shape (1,1,28,28) zum Zeitpunkt t gefüttert, und die letzte Schicht liefert ein 'vorhergesagtes' Rauschen \(\epsilon_{\Theta}\) vom gleichem shape (1,1,28,28). 

Die ASCII Grafik unten zeigt die verwendete UNet Struktur. Abkürzungen stehen für Hilfsmodule die im folgenden genauer erklärt werden. 

initial_conv +

Down = 3x(DownC.forward) mit num_DownC_layers =2:

    (conv1 + te + conv2 + attn + dspl) + (conv1 + te + conv2 + attn + dspl) ------------------------
     |----- skip -----|                   |----- skip -----|                                        |
                                                                                                    |
+   (conv1 + te + conv2 + attn + dspl) + (conv1 + te + conv2 + attn + dspl) ---------------------   |
         |----- skip -----|                   |----- skip -----|                                 |  |
                                                                                                 |  |
+   (conv1 + te + conv2 + attn       ) + (conv1 + te + conv2 + attn       ) ------------------   |  |
     |----- skip -----|                   |----- skip -----|                                  |  |  |
                                                                                              |  |  |
                                                                                              |  |  |
Mid = 3x(MidC.forward) mit num_MidC_layers=2:                                                 |  |  |
                                                                                              |  |  |
+    conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2)           |  |  |
     |----- skip -----|           |----- skip -----|            |----- skip -----|            |  |  |
                                                                                              |  |  |
+    conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2)           |  |  |
     |----- skip -----|           |----- skip -----|            |----- skip -----|            |  |  |
                                                                                              |  |  |
+    conv1 + te + conv2 + (attn + conv1 + te + conv2) + (attn + conv1 + te + conv2)           |  |  |
     |----- skip -----|           |----- skip -----|            |----- skip -----|            |  |  |
                                                                                              |  |  |
                                                                                              |  |  |
Up = 3x(UpC.forward) mit num_UpC_layers=2:                                                    |  |  |
                                                                                              |  |  |
            |------------------------------------------------------------------- skip --------   |  |
+   (       conv1 + te + conv2 + attn) + (       conv1 + te + conv2 + attn)                      |  |
            |----- skip -----|                   |----- skip -----|                              |  |
                                                                                                 |  |
            |------------------------------------------------------------------- skip ----------    |
+   (uspl + conv1 + te + conv2 + attn) + (uspl + conv1 + te + conv2 + attn)                         |
            |----- skip -----|                   |----- skip -----|                                 |
                                                                                                    |
            |------------------------------------------------------------------- skip --------------
+   (uspl + conv1 + te + conv2 + attn) + (uspl + conv1 + te + conv2 + attn)
            |----- skip -----|                   |----- skip -----|  

+   final_conv

'Conv(olution)' steht dabei für die Klasse 'NormActConv'. Eine Instanz dieser Klasse führt eine Normierung, anschließend eine Aktivierung und dann eine Faltung aus. Die Normierung erfolgt mit 'GroupNorm'. GroupNorm teilt die Kanäle in mehrere Gruppen (z. B. 8 Gruppen bei num_groups=8) und normalisiert jeden Kanal innerhalb seiner Gruppe. Das bedeutet: Für jede Gruppe wird  der Mittelwert \(\mu_G\) und die Standardabweichung \(\sigma_G\) berechnet. Dann wird jeder Wert in der Gruppe wie folgt transformiert: \(x'=(x−\mu_G)/(\sigma_G+\epsilon)\). Dabei ist \(\epsilon\) eine kleine Zahl zur Vermeidung von Division durch Null.

class NormActConv(nn.Module):
    """
    Performs Normalization, Activation and Convolution   
    """
    def __init__(self, 
                 in_channels:int, # Anzahl der Eingangskanäle
                 out_channels:int, # Anzahl der Ausgangskanäle
                 num_groups:int = 8, # Anzahl der Gruppen für GroupNorm (default = 8)
                 kernel_size: int = 3, # Kernelgröße für Conv2D (default = 3)
                 norm:bool = True, # GroupNorm wenn gesetzt
                 act:bool = True # SiLU Aktivierung x * sigmoid(x) wenn gesetzt
                ):

        super(NormActConv, self).__init__() 
        # GroupNorm
        self.g_norm = nn.GroupNorm(num_groups,in_channels) if norm is True else nn.Identity()
        
        # Activation
        self.act = nn.SiLU() if act is True else nn.Identity()
        
        # Convolution
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size, 
            padding=(kernel_size - 1)//2 #Padding wird so gewählt, dass die räumliche Größe erhalten bleibt
        )
        
    def forward(self, x):
        x = self.g_norm(x)
        x = self.act(x)
        x = self.conv(x)
        return x

'te' steht für 'TimeEmbeding', und beschreibt wie die Zeit-Einbettung ins Modell integriert wird. Zuerst wird die Eingabe mit der SiLU-Aktivierung geglättet, anschließend durch eine lineare Schicht in die gewünschte Ausgabedimension projiziert. Dazu wird ein linearer layer mit \(x' = x\cdot A + b\) eingefügt, wobei \(x\) die Eingabe mit t_emb_dim Dimensionen, \(x'\) die Ausgabe in n_out Dimensionen, \(A\) wie üblich eine zu trainierende Gewichtsmatrix und \(b\) der Bias ist.

class TimeEmbedding(nn.Module):
    """
    Maps the Time Embedding to the Required output Dimension.
    """
    def __init__(self, 
                 n_out:int, # Output Dimension
                 t_emb_dim:int = 128 # Time Embedding Dimension
                ):
        super(TimeEmbedding, self).__init__()
        
        # Time Embedding Block
        self.te_block = nn.Sequential(
            nn.SiLU(),  # Aktivierung
            nn.Linear(t_emb_dim, n_out) # 'embeding' durch Einfügen einer linearen Schicht
        )
    def forward(self, x):
        return self.te_block(x)

'attn' steht für  'SelfAttention', und erlaubt es dem Modell, nicht-lokale Beziehungen zu lernen, also wie z. B. ein Objekt in der oberen linken Ecke eines Bildes mit einem Objekt unten rechts zusammenhängt.  Eine einzelne Aufmersamkeitsfunktion besteht aus einer Abfrage \(Q\), einem Schlüssel \(K\) und einem Wert \(V\).  Der Schlüssel \(K\) stellt dabei eine Beziehungsmatrix der Bildausschnitte, oder von Wörtern zueinander dar. Der Wert \(V\) repräsentiert die Bedeutung eines Wortes bzw. den Wertvektor eines Bildausschnittes. Die Selbstaufmerksamkeit wird dann wie folgt berechnet:

$$SelfAttention(Q,K,V) = Softmax(QK^T/\sqrt{d_k})V$$

 Das Produkt \(QK^T\) misst, wie ähnlich das aktuelle Wort (Query) zu anderen Wörtern ist. Die Ähnlichkeitswerte werden durch die Softmax-Funktion in Wahrscheinlichkeiten umgewandelt. Damit die Werte nicht zu groß werden, teilt man durch \(\sqrt{𝑑_𝑘}\), wobei \(d_k\) die Dimension der Key-Vektoren ist. Diese Wahrscheinlichkeiten werden auf die Value-Vektoren \(V\) angewendet. So entsteht eine gewichtete Summe über alle 'values'. Bei der multihead Self Attention Operation wird die Eingabe X unter Verwendung unterschiedlicher Gewichtungsmatrizen W in mehrere kleinerdimensionale Unterräume projiziert.

$$Q_i = X\cdot W^Q_i, K_i = X\cdot W^K_i, V_i = X\cdot W^V_i $$

wobei \(i\) den Kopfindex bezeichnet. Jeder Kopf berechnet unabhängig voneinander seine eigene Selbstaufmerksamkeit mit der skalierten Punktproduktformel. Die Ausgänge aller Köpfe werden anschließend verkettet.

class SelfAttentionBlock(nn.Module):
    """
    Perform GroupNorm and Multiheaded Self Attention operation.    
    """
    def __init__(self, 
                 num_channels:int,
                 num_groups:int = 8, 
                 num_heads:int = 4,
                 norm:bool = True
                ):
        super(SelfAttentionBlock, self).__init__()
        
        # GroupNorm
        self.g_norm = nn.GroupNorm(num_groups,num_channels) if norm is True else nn.Identity()
        
        # Self-Attention
        self.attn = nn.MultiheadAttention(num_channels,num_heads, batch_first=True) #batch_first=True bedeutet, dass die Eingabe die Form (B, SeqLen, C) hat.
        
    def forward(self, x):
        batch_size, channels, h, w = x.shape
        
        x = x.reshape(batch_size, channels, h*w)
        # """
        # Wandelt das 'Bild' oder ein batch von Punkten oder 'Token' in eine Sequenz der Dimension 
        # (B, C, H*W). Jeder räumliche Punkt bei einem Bild (zB. die Kanäle R, G und B' mit 
        # Dimension C=3) wird als ein 'Token' oder Textstück betrachtet.
        # """
        x = self.g_norm(x) # Normalisiert die Kanäle gruppenweise
        x = x.transpose(1, 2) # neue Form: (B, H*W, C) – notwendig für MultiheadAttention
        x, _ = self.attn(x, x, x) # Führt die eigentliche MultiHeadSelfAttention durch. Query, Key und Value input sind gleich. Die zurückgegebene Gewichtungsmatrix wird ignoriert.
        x = x.transpose(1, 2).reshape(batch_size, channels, h, w) # Transponiert zurück zu (B, C, H*W) und formt dann zurück zu (B, C, H, W).
    
        return x
    

'dspl' steht für 'Downsample', und reduziert die räumliche Auflösung der 'Feature-Maps' um den Faktor \(k\). Die Klasse kombiniert 'strided Convolution' und 'Max-Pooling' und passt die Kanalanzahl an.

class Downsample(nn.Module):
    """
    Perform Downsampling by the factor of k across Height and Width.
    """
    def __init__(self, 
                 in_channels:int, 
                 out_channels:int, 
                 k:int = 2, # Downsampling Factor
                 use_conv:bool = True, # If Downsampling using conv-block
                 use_mpool:bool = True # If Downsampling using max-pool
                ):
        super(Downsample, self).__init__()
        
        self.use_conv = use_conv
        self.use_mpool = use_mpool
        
        # Downsampling using Convolution
        self.cv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1), #Eine 1×1-Conv, der die Kanäle nicht verändert, aber als Feature-Transformation dient.
            nn.Conv2d(in_channels, out_channels//2 if use_mpool else out_channels, kernel_size=4, stride=k, padding=1) #Eine 4×4-Conv mit Stride k (Filter bewegt sich in Schritten von k  Pixeln), der die räumliche Größe reduziert. Wenn use_mpool=True, wird die Anzahl der Ausgabekanäle halbiert (out_channels//2), da der max-pool Zweig auch Features liefert.
        ) if use_conv else nn.Identity()
        
        # Downsampling using Maxpool
        self.mpool = nn.Sequential(
            nn.MaxPool2d(k, k), #Downsampling durch Pooling mit Fenstergröße und Stride k.
            nn.Conv2d(in_channels, out_channels//2 if use_conv else out_channels, kernel_size=1, stride=1, padding=0) #1×1-Conv zur Kanaltransformation nach dem Pooling.
        ) if use_mpool else nn.Identity()


    def forward(self, x):
        
        if not self.use_conv: # Nur MaxPool-Zweig wird verwendet.
            return self.mpool(x)
        
        if not self.use_mpool: # Nur Convolution-Zweig wird verwendet.
            return self.cv(x)

        return torch.cat([self.cv(x), self.mpool(x)], dim=1) # cv und mpool werden entlang der Kanalachse (dim=1) zusammengefügt.

'uspl' steht für 'UpSample', und vergrößert hingegen Bildmerkmale um den Faktor \(k\). Das erfolgt per transponierter Faltung, bilinearer Interpolation oder beidem. 

class Upsample(nn.Module):
    """
    Perform Upsampling by the factor of k across Height and Width
    """
    def __init__(self, 
                 in_channels:int, 
                 out_channels:int, 
                 k:int = 2, # Upsampling Faktor
                 use_conv:bool = True, # Upsampling durch 'conv-block' (transposed convolution) wenn gesetzt
                 use_upsample:bool = True # Upsampling durch 'upsample' (bilinear upsampling) wenn gesetzt
                ):
        super(Upsample, self).__init__()
        
        self.use_conv = use_conv
        self.use_upsample = use_upsample
        
        # Upsampling using conv
        self.cv = nn.Sequential(
            nn.ConvTranspose2d(in_channels,out_channels//2 if use_upsample else out_channels, kernel_size=4, stride=k, padding=1), # Transposed Convolution fungiert als 'de convolution'. Stride=k sorgt dafür, dass die räumliche Größe um den Faktor k vergrößert wird.
            nn.Conv2d(out_channels//2 if use_upsample else out_channels, out_channels//2 if use_upsample else out_channels, kernel_size = 1, stride=1, padding=0) # Ein 1×1-Conv zur Kanalmischung nach dem Upsampling.
        ) if use_conv else nn.Identity()
        
        # Upsamling using nn.Upsample
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=k, mode = 'bilinear', align_corners=False), #Vergrößert die Höhe und Breite um den Faktor k durch Interpolation. mode='bilinear' ist eine gängige Methode für glatte Ergebnisse.
            nn.Conv2d(in_channels,out_channels//2 if use_conv else out_channels, kernel_size=1, stride=1, padding=0) # Auch hier ein 1×1-Conv zur Anpassung der Kanalanzahl.
        ) if use_upsample else nn.Identity()
        
    def forward(self, x):
        
        if not self.use_conv:
            return self.up(x) # Nur bilineares Upsampling wird verwendet.
        if not self.use_upsample:
            return self.cv(x) #  Nur Transposed Convolution wird verwendet.

        return torch.cat([self.cv(x), self.up(x)], dim=1) # Beide Methoden werden kombiniert, und die Ergebnisse entlang der Kanalachse (dim=1) zusammengefügt. Dadurch entsteht ein Feature-Map mit out_channels Kanälen. 

Damit sind alle Hilfsklassen definiert. 

Um das U-Netz modular zu gestalten werden nun die Klassen 'DownC', 'MidC' und 'UpC' definiert. `DownC` verarbeitet die Eingabeschicht durch wiederholte Convolution, Zeit-Embedding, Self-Attention und anschließendes Downsampling. Wie in einem residual network (ResNet), werden auch sogenannte 'skip connections' verwendet. Ein allgemeiner residual block in einem ResNet sieht wie folgt aus:

Über eine 'skip connection'  wird der Eingang eines Teilnetzes mit dessen Ausgang verbunden. Dadurch wird das bei tiefen Transformerstrukturen auftretende Problem des verschwindenden Gradienten abgeschwächt. In der Tat können sehr tiefe Transformerstrukturen nicht ohne 'skip connections' trainiert werden

class DownC(nn.Module):
    """
    Perform Down-convolution on the input using following general approach.
    1. Convolution
    2. TimeEmbedding
    3. Convolution
    4. Self-Attention
    5. Downsampling
    """
    def __init__(self, 
                 in_channels:int, 
                 out_channels:int, 
                 t_emb_dim:int = 128, # Time Embedding Dimension
                 num_layers:int=2,
                 down_sample:bool = True # True for Downsampling
                ):
        super(DownC, self).__init__()
        
        self.num_layers = num_layers

        # conv1: Erste Convolution vor Zeit-Embedding
        self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers)])
        # conv2: Zweite Convolution nach Addition des Embeddings
        self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers)])
        # Zeit-Embedding
        self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers)])
        # Self-Attention 
        self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)])
        # optionales Downsampling
        self.down_block =Downsample(out_channels, out_channels) if down_sample else nn.Identity()
        # 1×1-Conv für die Skip-Connection     
        self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers)])

    def forward(self, x, t_emb):
        
        out = x
        for i in range(self.num_layers):
            resnet_input = out # Schritt 1: Speichern des Inputs für Skip-Connection
            out = self.conv1[i](out) #Schritt 2: Erste Convolution
            out = out + self.te_block[i](t_emb)[:, :, None, None] #Schritt 3: Zeit-Embedding wird auf die räumliche Form „gebroadcasted“ (mit None erweitert) und hinzugefügt.
            out = self.conv2[i](out) # Schritt 4: Zweite Convolution
            out = out + self.res_block[i](resnet_input) # Schritt 5: Skip-Connection
            out_attn = self.attn_block[i](out) # Schritt 6: Self-Attention
            out = out + out_attn # Attention-Ergebnis wird hinzugefügt

        return out

Die Klasse `MidC` verfeinert Feature-Maps aus dem Downsampling-Pfad durch einen initialen ResNet-Block mit Zeit-Embedding, gefolgt von mehreren Self-Attention- und ResNet-Blöcken, die ebenfalls zeitliche Informationen integrieren. 

class MidC(nn.Module):
    """
    Refine the features obtained from the DownC block. It refines the features using following operations:
    1. Resnet Block with Time Embedding
    2. A Series of Self-Attention + Resnet Block with Time-Embedding 
    """
    def __init__(self, 
                 in_channels:int, 
                 out_channels:int,
                 t_emb_dim:int = 128,
                 num_layers:int = 2
                ):
        super(MidC, self).__init__() 

        self.num_layers = num_layers
        self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers + 1)])
        self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers + 1)])
        self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers + 1)])
        self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)])
        self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers + 1)])

    def forward(self, x, t_emb):
        out = x
        
        # Erster-Resnet Block
        resnet_input = out
        out = self.conv1[0](out)
        out = out + self.te_block[0](t_emb)[:, :, None, None]
        out = self.conv2[0](out)
        out = out + self.res_block[0](resnet_input)
        
        # Abfolge von Self-Attention + Resnet Blocks
        for i in range(self.num_layers):
            
            # Self Attention
            out_attn = self.attn_block[i](out)
            out = out + out_attn
            
            # Resnet Block
            resnet_input = out
            out = self.conv1[i+1](out)
            out = out + self.te_block[i+1](t_emb)[:, :, None, None]
            out = self.conv2[i+1](out)
            out = out + self.res_block[i+1](resnet_input)
            
        return out

Die Klasse `UpC` führt eine mehrschichtige Up-Convolution durch, kombiniert Upsampling, zeitliche Einbettung, Residualverbindungen und Selbstaufmerksamkeit, um Feature-Maps effizient zu verarbeiten und mit Encoder-Ausgaben zu verknüpfen.

class UpC(nn.Module):
    """
    Perform Up-convolution on the input using following approach.
    1. Upsampling
    2. Convolution
    3. TimeEmbedding
    4. Convolution
    6. Self-Attention
    """
    def __init__(self, in_channels:int, 
                 out_channels:int, 
                 t_emb_dim:int = 128, # Time Embedding Dimension
                 num_layers:int = 2,
                 up_sample:bool = True # Upsampling wenn Wahr
                ):
        
        super(UpC, self).__init__()
        
        self.num_layers = num_layers
        self.conv1 = nn.ModuleList([NormActConv(in_channels if i==0 else out_channels, out_channels) for i in range(num_layers)])
        self.conv2 = nn.ModuleList([NormActConv(out_channels, out_channels) for _ in range(num_layers)])
        self.te_block = nn.ModuleList([TimeEmbedding(out_channels, t_emb_dim) for _ in range(num_layers)])
        self.attn_block = nn.ModuleList([SelfAttentionBlock(out_channels) for _ in range(num_layers)])
        self.up_block =Upsample(in_channels, in_channels//2) if up_sample else nn.Identity()
        self.res_block = nn.ModuleList([nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers)])
        
    def forward(self, x, down_out, t_emb):
        
        # Upsampling
        x = self.up_block(x)
        x = torch.cat([x, down_out], dim=1)
        
        out = x
        for i in range(self.num_layers):
            resnet_input = out
            
            # Resnet Block
            out = self.conv1[i](out)
            out = out + self.te_block[i](t_emb)[:, :, None, None]
            out = self.conv2[i](out)
            out = out + self.res_block[i](resnet_input)

            # Self Attention
            out_attn = self.attn_block[i](out)
            out = out + out_attn
        
        return out

Die Klasse `Unet` implementiert schließlich die U-Net-Architektur. Sie kombiniert Downsampling-, Mittelkern- und Upsampling-Blöcke mit Zeit-Embedding, Residualverbindungen und Selbstaufmerksamkeit zur Verarbeitung und Rekonstruktion von Bildmerkmalen. Im folgenden Code wurde die Zahl der 'DownC', 'MidC' und 'UpC' layer von 2 auf 1 reduziert um Rechenzeit zu sparen. 

class Unet(nn.Module):
    """
    U-net architecture which is used to predict noise.
    U-net consists of Series of DownC blocks followed by MidC
    followed by UpC.
    """
    
    def __init__(self,
                 im_channels: int = 1, # Grayscale
                 down_ch: list = [32, 64, 128], #down_ch: list = [32, 64, 128, 256],
                 mid_ch: list = [128, 128], #mid_ch: list = [256, 256, 128],
                 up_ch: list[int] = [128, 64, 16], #[256, 128, 64, 16],
                 down_sample: list[bool] = [True, True], #down_sample: list[bool] = [True, True, False],
                 t_emb_dim: int = 128,
                 num_downc_layers:int = 1, #2 
                 num_midc_layers:int = 1, #2
                 num_upc_layers:int = 1 #2
                ):
        super(Unet, self).__init__()
        
        self.im_channels = im_channels
        self.down_ch = down_ch
        self.mid_ch = mid_ch
        self.up_ch = up_ch
        self.t_emb_dim = t_emb_dim
        self.down_sample = down_sample
        self.num_downc_layers = num_downc_layers
        self.num_midc_layers = num_midc_layers
        self.num_upc_layers = num_upc_layers
        
        self.up_sample = list(reversed(self.down_sample)) # [False, True, True]
        
        # Initial Convolution (1=>32)
        self.cv1 = nn.Conv2d(self.im_channels, self.down_ch[0], kernel_size=3, padding=1) #Faltung die die Zahl der Bildkanäle auf 32 vergrößert
        
        # Initial Time Embedding Projection
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim), #linear function dim 128 => dim 128
            nn.SiLU(), #activation
            nn.Linear(self.t_emb_dim, self.t_emb_dim) #linear function
        )
        
        # DownC Block
        self.downs = nn.ModuleList([
            DownC(
                self.down_ch[i], 
                self.down_ch[i+1], # Zahl der Bildkanäle wird schrittweise auf 64 => 128 => 256 vergrößert
                self.t_emb_dim, #128 
                self.num_downc_layers, #2 
                self.down_sample[i] # wenn down_sampling  => false => false
            ) for i in range(len(self.down_ch) - 1)
        ])
        
        # MidC Block
        self.mids = nn.ModuleList([
            MidC(
                self.mid_ch[i], 
                self.mid_ch[i+1], 
                self.t_emb_dim, 
                self.num_midc_layers
            ) for i in range(len(self.mid_ch) - 1)
        ])
        
        # UpC Block
        self.ups = nn.ModuleList([
            UpC(
                self.up_ch[i], 
                self.up_ch[i+1], 
                self.t_emb_dim, 
                self.num_upc_layers, 
                self.up_sample[i]
            ) for i in range(len(self.up_ch) - 1)
        ])
        
        # Final Convolution
        self.cv2 = nn.Sequential(
            nn.GroupNorm(8, self.up_ch[-1]), 
            nn.Conv2d(self.up_ch[-1], self.im_channels, kernel_size=3, padding=1)
        )
        
    def forward(self, x, t): # wird beim Aufruf von Unet automatisch ausgeführt
        
        out = self.cv1(x) # vergrößert die Zahl der Bildkanäle auf 32
        # Time Projection
        t_emb = get_time_embedding(t, self.t_emb_dim) #(128,128) Tensor bestehend aus sin und cos Werten von t bei unterschiedlichen Frequenzen
        t_emb = self.t_proj(t_emb) # ein batch von 128 time_embedings der Länge 128 wird (parallel) durch einen Linear Layer (128 -> 128) geschickt, aktiviert und erneut durch einen LL (128->128) geschickt 
        
        # DownC outputs
        down_outs = []
        
        for down in self.downs: #2x
            down_outs.append(out)
            out = down(out, t_emb)

        # MidC outputs
        for mid in self.mids:
            out = mid(out, t_emb)
        
        # UpC Blocks
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)
            
        # Final Conv
        out = self.cv2(out)
        
        return out

Damit ist das Netzwerk endlich definiert und wir können mit dem Training beginnen. Ich möchte für das Training den MNIST Datensatz aus Graustufenbildern der Größe 28 x 28 verwenden.


Zunächst definieren wir, wie wir unseren Trainingsdatensatz einlesen wollen.

import pandas as pd
import numpy as np
import torchvision
from torch.utils.data.dataset import Dataset
from PIL import Image

class CustomMnistDataset(Dataset):
    """
    Reads the MNIST data from csv file given file path.
    """
    def __init__(self, csv_path, num_datapoints = None):
        super(CustomMnistDataset, self).__init__()
        
        self.df = pd.read_csv(csv_path)
        
        if num_datapoints is not None:
            self.df = self.df.iloc[0:num_datapoints]
      
    def __len__(self):
        return len(self.df)
    
    def  __getitem__(self, index):
        # Read
        img = self.df.iloc[index].filter(regex='pixel').values
        img = np.reshape(img, (28, 28)).astype(np.uint8) 
        
        # Convert to Tensor
        img_tensor = torchvision.transforms.ToTensor()(img) # [0, 1]
        img_tensor = 2*img_tensor - 1 # [-1, 1]
        
        return img_tensor

Dann werden die Metadaten gesetzt,

class CONFIG:
    model_path = 'ddpm_unet.pth'
    train_csv_path = '/kaggle/input/train-mod30k/train_mod30k.csv'
    test_csv_path = '/kaggle/input/digit-recognizer/test.csv'
    generated_csv_path = 'mnist_generated_data.csv'
    num_epochs = 5 #50
    lr = 1e-4
    num_timesteps = 1000
    batch_size = 128
    img_size = 28 #28
    in_channels = 1
    num_img_to_generate = 5 #256

und anschließend der Trainingsloop definiert.

from torch.utils.data import DataLoader
from tqdm import tqdm

def train(cfg):
    
    # Dataset and Dataloader
    mnist_ds = CustomMnistDataset(cfg.train_csv_path)
    mnist_dl = DataLoader(mnist_ds, cfg.batch_size, shuffle=True)
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # CUDA (Compute Unified Device Architecture) ist eine von Nvidia entwickelte Programmierschnittstelle (API), mit der Programmteile durch den Grafikprozessor (GPU) abgearbeitet werden können.
    print(f'Device: {device}\n')
    
    # Initiate Model
    model = Unet().to(device)
    
    # Initialize Optimizer and Loss Function
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) # Adam-Optimieren passst die Lernrate automatisch an. model.parameters() enthält alle lernbaren Gewichte
    criterion = torch.nn.MSELoss() #berechnet den Mean SquaredError also den mittleren Quadratischen Fehler zwischen zwei Tensoren (siehe Zeile 55)
    
    # Diffusion Forward Process to add noise
    dfp = DiffusionForwardProcess()
    
    # Best Loss
    best_eval_loss = float('inf')
    
    # Train
    for epoch in range(cfg.num_epochs):
        
        # For Loss Tracking
        losses = []
        
        # Set model to train mode
        model.train()
        
        # Loop over dataloader
        for imgs in tqdm(mnist_dl):
            
            imgs = imgs.to(device) #(128,1,28,28), batch_size in CONFIG ist 128, die bilder sind Graustufenbilder der Größe 28 x 28
             
            # Generate noise and timestamps
            noise = torch.randn_like(imgs).to(device) #für jeds Bild wird ein anderes epsilon (Rauschen) verwendet
            t = torch.randint(0, cfg.num_timesteps, (imgs.shape[0],)).to(device) #batch_size=imgs.shape[0]=128. Jedes Bild wird bis zu einem anderen Zeitpunkt t verrauscht.
            
            # Add noise to the images using Forward Process
            noisy_imgs = dfp.add_noise(imgs, noise, t) #erzeuge 128 verrauschte Bilder zum unterschiedlichen Zeitpunkten t (Es werden nicht alle Zeitpunkte für jedes Bild zum Training herangezogen)
            
            # Avoid Gradient Accumulation
            optimizer.zero_grad() #alte Gradienten löschen
            
            # Predict noise using U-net Model
            noise_pred = model(noisy_imgs, t)
            
            # Calculate Loss
            loss = criterion(noise_pred, noise) #berechne den mittleren quadratischen Fehler (MSE) zwischen dem Rauschen und dem vorhergesagten Rauschen
            
            losses.append(loss.item())
            
            # Backprop + Update model params
            loss.backward() # Gradienten berechnen
            optimizer.step() # Modellparameter des U-Netzes anpassen
        
        # Mean Loss
        mean_epoch_loss = np.mean(losses) #mittlerer Loss über alle Bilder
        
        # Display
        print('Epoch:{} | Loss : {:.4f}'.format(
            epoch + 1,
            mean_epoch_loss,
        ))
        
        # Save based on train-loss
        if mean_epoch_loss < best_eval_loss:
            best_eval_loss = mean_epoch_loss
            torch.save(model, cfg.model_path)
            
    print(f'Done training.....')

Dann startet das Training.

# Config
cfg = CONFIG()

# TRAIN
train(cfg)

Ist das Training beendet, und das Model abgespeichert, können neue Bilder generiert werden.

def generate(cfg):
    """
    Given Pretrained DDPM U-net model, Generate Real-life
    Images from noise by going backward step by step. i.e.,
    Mapping of Random Noise to Real-life images.
    """
    
    # Device setzen
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Diffusion Reverse Process initiieren
    drp = DiffusionReverseProcess()
    
    # Model in Evaluierungs Modus versetzten
    model = torch.load(cfg.model_path, weights_only=False).to(device) #`weights_only=True`
    model.eval()
    
    # Ein verrauschtes Bild mit N(0, 1) generieren
    xt = torch.randn(1, cfg.in_channels, cfg.img_size, cfg.img_size).to(device)
    
    # Schritt für Schritt entrauschen
    with torch.no_grad(): # Keine Berechnung von Gradienten bei der Bildgenerierung (Inferenz))
        for t in reversed(range(cfg.num_timesteps)):
            noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device)) # Wandelt t in einen Tensor um, falls es z. B. ein einfacher Integer ist.unsqueeze(0): Fügt eine zusätzliche Dimension hinzu, sodass t die Form (1,) hat – passend zur Batch-Größe.
            xt, x0 = drp.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device)) # Gibt x_(t-1) und x0 zurück

    # Bild skalieren
    xt = torch.clamp(xt, -1., 1.).detach().cpu() #Begrenzung der Werte im Tensor xt auf den Bereich von −1 bis +1.detach() trennt den Tensor vom Berechnungsgraphen. Das bedeutet: Es werden keine Gradienten mehr berechnet, und der Tensor ist nun „autonom“.cpu(): Verschiebt den Tensor vom GPU-Speicher zurück in den CPU-Speicher, damit er z. B. mit NumPy weiterbearbeitet werden kann.
    xt = (xt + 1) / 2
    
    return xt # gibt das 'entrauschte' Bild zurück

Auch die erzeugten Bilder werden abgespeichert, und anschließend dargestellt. Aus Zeitgründen habe ich in der Regel nur 5 Bilder generieren lassen um die Qualität der Ergebnisse bewerten zu können.

# Model laden und konfigurieren
cfg = CONFIG()

# Bild generieren
generated_imgs = []
for i in tqdm(range(cfg.num_img_to_generate)):
    xt = generate(cfg)
    xt = 255 * xt[0][0].numpy()
    generated_imgs.append(xt.astype(np.uint8).flatten())

# Erzeugte Daten speichern
generated_df = pd.DataFrame(generated_imgs, columns=[f'pixel{i}' for i in range(784)])
generated_df.to_csv(cfg.generated_csv_path, index=False)

# Visualisieren
from matplotlib import pyplot as plt
fig, axes = plt.subplots(1, 5, figsize=(5, 5))

# Jedes Bild in einem Subplot darstellen
for i, ax in enumerate(axes.flat):
    ax.imshow(np.reshape(generated_imgs[i], (28, 28)), cmap='gray')  

plt.tight_layout()  # Platz zwischen Subplots anpassen
plt.show()

Nach nur 5 Trainingsdurchläufen mit einer Laufzeit von  je 2 Minuten und 25 Sekunden und unter Verwendung von 42 Tausend Bildern, habe ich folgende Ergebnisse erhalten.


Dies stellt bereits eine deutliche Verbesserung gegenüber dem 'ersten Versuch ein Denoising Probalistic Model zu programmieren' dar. Bei Verwendung von 30 Tausend Bildern reduziert sich die  Laufzeit pro 'Epoche'  auf 1 Minute und  40 Sekunden. Die Qualität nimmt ab, ist aber immer noch akzeptabel.


Eine weitere Beschleunigung lässt sich erzielen indem wir die Zahl der 'Upc', 'Midc' und 'Downc' layer auf 1 reduzieren und den letzten bzw. ersten layer von 'Down' bzw. 'Up' in der Klasse 'Unet' weglassen.
Ein Durchlauf dauert dann nur noch 49 Sekunden.


Es sei darauf hingewiesen, dass sich diese schnellen Laufzeiten  nur unter Verwendung einer GPU (hier einer NVIDIA Tesla P100 ) erzielen lassen.  Das gesamte Programm lässt sich unterdiesem link als Kaggle Notebook ausführen. Im  Notebook finden sich auch links auf weiterführende Informationen.
Bei  ausschließlicher Verwendung einer CPU dauert der obige Trainingsdurchlauf zum Vergleich nicht 49 Sekunden sondern 42 Minuten, also ca. 51 mal so lang.
Letztlich habe ich das Netzwerk noch mit 40 Durchläufen trainiert. Der Verlauf des 'Losses' ist unten gezeigt.




Das Ergebnis

kommt dem ursprünglichen Datensatz 


schon sehr Nahe.

Viel Spaß beim selber programmieren und experimentieren...

Keine Kommentare:

Kommentar veröffentlichen