Diffusion Deep Dive Part 3: Coding a DDPM from Scratch

Suchinthaka W. · 13 min read

Tags: diffusion · generative-models · deep-learning · ddpm · pytorch · machine-learning

In Part 1 we derived the DDPM training objective and the DDPM sampler. In Part 2 we extended that sampler to DDIM so the same network can generate in $\sim 50$ steps instead of $1000$. This post turns the math into running code.

$$\mathcal{L}_{\mathrm{simple}}(\theta) = \mathbb{E}_{t,\,\mathbf{x}_0,\,\boldsymbol{\epsilon}}\left[ \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\big(\sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon},\; t\big) \big\|^2 \right]$$

We build a DDPM that trains on MNIST in about an hour on a single GPU (or a few hours on CPU, if you are patient). The point is not state-of-the-art image quality; it is to have a minimal, readable reference where every line maps to an equation from Part 1 and every sampler change maps to Part 2.

What we will build:

  1. The noise schedule $\{\beta_t, \alpha_t, \bar\alpha_t\}$.
  2. The closed-form forward process $q(\mathbf{x}_t \mid \mathbf{x}_0)$.
  3. A small UNet $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ with sinusoidal time embeddings.
  4. The training loop (a few lines of math, about ten of code).
  5. The iterative DDPM sampler (Algorithm 2 from Ho et al.).
  6. Practical stability tricks (EMA, gradient clipping).

Reference: Ho, Jain, Abbeel, "Denoising Diffusion Probabilistic Models" (NeurIPS 2020). All equation numbers in this post refer to Part 1 of this series.

Setup

We will use PyTorch and torchvision; nothing else is required.

python
import math
from dataclasses import dataclass
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
device = "cuda" if torch.cuda.is_available() else "cpu"

For the dataset, MNIST is the right starting point: it is small, the images are $28 \times 28$, and you can see whether the model is working within a few epochs. We resize to $32 \times 32$ (a power of two, which makes the UNet's down and up sampling clean) and scale pixels to $[-1, 1]$ so the data looks roughly standard-normal, matching the noise we will add.

python
def get_loader(batch_size=128):
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),             # [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)

The Noise Schedule

Recall the definitions from Part 1: $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$. The noise schedule $\{\beta_t\}$ controls how aggressively we destroy the signal at each step. Ho et al. use a linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$. It is simple and works well enough for $32 \times 32$ images.

We precompute every quantity we will ever need, for every timestep, once at init. This avoids repeated cumulative products inside the training loop.

python
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor          # (T,)
    alphas: torch.Tensor         # (T,)
    alphas_bar: torch.Tensor     # (T,)
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor  # tilde-beta_t
 
 
def make_schedule(T=1000, beta_start=1e-4, beta_end=2e-2, device="cpu"):
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
 
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
 
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )

Two notes worth calling out:

  • alphas_bar_prev is $\bar\alpha_{t-1}$. We set $\bar\alpha_0 = 1$ by convention (nothing has been corrupted at $t = 0$), which is the pad(..., value=1.0) trick.
  • posterior_variance is $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$ from Part 1. This is what we use as $\sigma_t^2$ at sampling time.
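A quick sanity check on the schedule never hurts. The sketch below rebuilds the linear schedule inline (so it is self-contained, independent of make_schedule) and confirms the signal coefficient $\sqrt{\bar\alpha_t}$ decays monotonically from almost 1 to almost 0:

```python
import torch

# Rebuild the linear schedule inline so this snippet stands alone.
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)
sqrt_ab = alphas_bar.sqrt()

# At t=0 almost all signal survives; at t=T-1 almost none does.
print(sqrt_ab[0].item())    # ≈ 0.99995
print(sqrt_ab[-1].item())   # ≈ 0.006 (x_T is essentially pure noise)
```

If the last value is not close to zero, the forward process never fully destroys the data and $\mathbf{x}_T$ will not match the $\mathcal{N}(\mathbf{0}, \mathbf{I})$ prior we sample from at generation time.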

The Forward Process (Closed Form)

From Part 1, equation (5), we can sample $\mathbf{x}_t$ from $\mathbf{x}_0$ in one shot:

$$\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$

No loop over $t$. This is the single most important efficiency property of diffusion training.

python
def gather(buf, t):
    """buf: (T,), t: (B,) -> (B, 1, 1, 1) for broadcasting over (B, C, H, W)."""
    return buf.gather(0, t).view(-1, 1, 1, 1)
 
 
def q_sample(x0, t, noise, sched):
    """Sample x_t from q(x_t | x_0) in closed form."""
    sqrt_ab = gather(sched.sqrt_alphas_bar, t)
    sqrt_one_minus_ab = gather(sched.sqrt_one_minus_alphas_bar, t)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

gather looks trivial but it is load-bearing: each sample in the batch has its own $t$, so we need per-sample coefficients broadcast over the spatial dimensions.
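To make the broadcast concrete, here is a self-contained sketch (schedule rebuilt inline, random data standing in for MNIST) that gives each of three samples its own timestep and checks the shapes:

```python
import torch

# Inline linear schedule, as in the post.
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)
sqrt_ab = alphas_bar.sqrt()
sqrt_1m_ab = (1.0 - alphas_bar).sqrt()

t = torch.tensor([0, 500, 999])                 # one timestep per sample
coef = sqrt_ab.gather(0, t).view(-1, 1, 1, 1)   # (3,) -> (3, 1, 1, 1)

x0 = torch.randn(3, 1, 32, 32)
noise = torch.randn_like(x0)
x_t = coef * x0 + sqrt_1m_ab.gather(0, t).view(-1, 1, 1, 1) * noise

print(coef.shape)  # torch.Size([3, 1, 1, 1])
print(x_t.shape)   # torch.Size([3, 1, 32, 32])
```

Without the .view(-1, 1, 1, 1), the (3,) coefficient vector would broadcast against the wrong axis of the (3, 1, 32, 32) batch and silently corrupt the noising.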

The Network: A Small UNet with Time Embeddings

The network $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ takes a noisy image and a timestep and outputs a prediction of the noise, with the same shape as the input. A UNet is the canonical choice: encoder-decoder with skip connections, which preserves spatial detail that would otherwise be lost during downsampling.

Sinusoidal Time Embeddings

The timestep $t$ is an integer in $\{1, \ldots, T\}$, but the network needs a continuous, information-rich representation. We use the same sinusoidal embedding as the Transformer (and for the same reason: it encodes position on many frequency scales at once). This embedding is then projected through an MLP and injected additively into every residual block.

python
class TimeEmbedding(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )
 
    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)   # (B, dim)
        return self.mlp(emb)                                # (B, hidden)
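A shape check on the sinusoidal part alone (before the MLP), as a standalone sketch mirroring the forward pass above:

```python
import math
import torch

dim = 128
half = dim // 2
t = torch.tensor([1, 500, 1000])

# Geometric frequency ladder, identical to the module above.
freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
args = t[:, None].float() * freqs[None]
emb = torch.cat([args.sin(), args.cos()], dim=-1)

print(emb.shape)  # torch.Size([3, 128])
# Every entry is a sin or cos, so the embedding is bounded in [-1, 1],
# and distinct timesteps map to distinct vectors.
```

The boundedness matters: the embedding feeds a plain MLP, so keeping its inputs in a fixed range avoids any timestep-dependent scale drift.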

Residual Block with Time Conditioning

Each residual block does: group-norm, SiLU, conv; inject time embedding (one scalar per channel, added as bias); group-norm, SiLU, conv; plus a skip connection. This is a simplified version of the block in Ho et al.

python
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
 
    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

The UNet

Three resolutions: $32 \to 16 \to 8$. Each level has one residual block going down, one at the bottleneck, and one going up, with skip connections concatenating the encoder features into the decoder. Keep it small. This is enough for MNIST and will train in minutes on a modern GPU.

python
class UNet(nn.Module):
    def __init__(self, in_ch=1, base=64, t_dim=128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
 
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
 
        # Encoder
        self.down1 = ResBlock(base,     base,     t_out)
        self.pool1 = nn.Conv2d(base,     base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
 
        # Bottleneck
        self.mid   = ResBlock(base * 4, base * 4, t_out)
 
        # Decoder (skip-connect by concatenation, so input channels double)
        self.up2   = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2  = ResBlock(base * 4, base * 2, t_out)
        self.up1   = nn.ConvTranspose2d(base * 2, base,     4, stride=2, padding=1)
        self.dec1  = ResBlock(base * 2, base,     t_out)
 
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)
 
    def forward(self, x, t):
        t_emb = self.time(t)
 
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)        # (B, base,     32, 32)
        h2 = self.down2(self.pool1(h1), t_emb)  # (B, 2*base, 16, 16)
        hb = self.mid(self.pool2(h2), t_emb)    # (B, 4*base,  8,  8)
 
        u2 = self.up2(hb)                          # (B, 2*base, 16, 16)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)                          # (B, base,   32, 32)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
 
        return self.out_conv(F.silu(self.out_norm(u1)))

A real implementation (Ho et al., OpenAI guided-diffusion, Stable Diffusion) adds self-attention at low resolutions, multiple residual blocks per level, and more channels. For the purposes of this post, the architecture above is intentionally stripped down.
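One detail worth verifying is the resolution bookkeeping: the encoder halves the spatial size with a stride-2 3x3 conv and the decoder doubles it with a stride-2 4x4 transposed conv, and those kernel choices are what make the two exactly invert each other's shapes. A quick sketch of the arithmetic in isolation:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)

# Same layer shapes as pool1 / up1 in the UNet above.
down = nn.Conv2d(64, 128, 3, stride=2, padding=1)          # 32 -> 16
up = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)   # 16 -> 32

h = down(x)
print(h.shape)       # torch.Size([2, 128, 16, 16])
print(up(h).shape)   # torch.Size([2, 64, 32, 32])
```

With a 3x3 transposed conv instead, the output would come back as 31x31 and the skip-connection concatenation in the decoder would fail with a shape mismatch.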

Training

The training step is the one you came here for. Sample $(\mathbf{x}_0, t, \boldsymbol{\epsilon})$, form $\mathbf{x}_t$ in closed form, predict the noise, compute MSE. Four lines. (Note that the code indexes timesteps from $0$ to $T-1$ while the math runs over $1$ to $T$; the precomputed buffers use the same 0-based convention throughout, so the two stay consistent.)

python
def train_step(model, x0, sched, optimizer):
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
 
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
 
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

Gradient clipping at norm $1.0$ is a cheap safety net: occasional outlier batches (very small $\bar\alpha_t$, where $\mathbf{x}_t$ is almost pure noise and the loss surface is flat) can produce huge gradients. Without clipping, a single bad step can destabilize training.

EMA of Weights

One practical detail that matters a lot for sample quality: keep an exponential moving average (EMA) of the model weights, and sample from the EMA copy instead of the live model. Diffusion losses are noisy (different $t$ every step), so the live weights oscillate; the EMA smooths this out. Typical decay is $0.9999$.

python
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
 
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)
 
    def copy_to(self, model):
        model.load_state_dict(self.shadow)
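To see what the decay actually does, here is a toy sketch with a single hypothetical scalar "weight" jittering around 1.0 (decay lowered to 0.999 so the smoothing shows up within a few thousand steps; with the real decay of $0.9999$, an update's influence halves only after about $\ln 2 / 10^{-4} \approx 6{,}900$ steps):

```python
import torch

torch.manual_seed(0)
decay = 0.999                 # toy value; the post uses 0.9999
shadow = torch.tensor(0.0)    # EMA starts cold
target = torch.tensor(1.0)

for _ in range(5000):
    # The "live weight" oscillates around its target, like noisy SGD iterates.
    live = target + 0.1 * torch.randn(())
    shadow = decay * shadow + (1 - decay) * live

print(round(shadow.item(), 1))  # ≈ 1.0: the EMA tracks the mean, not the jitter
```

The same averaging applied per-parameter is what the EMA class above does, which is why samples from the shadow weights look cleaner than samples from the live model.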

The Full Training Loop

python
def train(epochs=30, batch_size=128, lr=2e-4, T=1000):
    loader = get_loader(batch_size)
    sched  = make_schedule(T=T, device=device)
 
    model = UNet(in_ch=1, base=64).to(device)
    ema   = EMA(model, decay=0.9999)
    opt   = torch.optim.AdamW(model.parameters(), lr=lr)
 
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
 
    return model, ema, sched

On an RTX 4090, 30 epochs of MNIST takes roughly 15 to 20 minutes and produces legible digits. You should see the training loss drop quickly from $\sim 1.0$ to $\sim 0.04$ within the first epoch and then crawl down slowly. Do not expect it to go to zero: the loss is an expectation over all noise levels, and the high-$t$ regime is irreducibly hard (the network cannot predict the noise when the signal is basically gone).

Sampling

Sampling is where the iterative structure comes back. We cannot skip timesteps the way we can in training: to sample $\mathbf{x}_{t-1}$ we need $\mathbf{x}_t$. Algorithm 2 from DDPM:

Starting from $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, for $t = T, T-1, \ldots, 1$:

$$\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right) + \sigma_t \mathbf{z},$$

where $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ for $t > 1$ and $\mathbf{z} = \mathbf{0}$ at the last step. We use $\sigma_t^2 = \tilde\beta_t$ (the posterior variance we precomputed).

The formula is exactly the $\boldsymbol{\mu}_\theta$ expression from Part 1, Step 4 of the noise-prediction parameterization, plus a Gaussian noise injection.

python
@torch.no_grad()
def sample(model, sched, n=16, img_size=32, channels=1, device=device):
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
 
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
 
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
 
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean  # no noise at the last step
 
    model.train()
    return x.clamp(-1, 1)
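A useful numeric sanity check on this mean, sketched below: when $\boldsymbol{\epsilon}$ is the true noise that formed $\mathbf{x}_t$, the $\epsilon$-parameterized mean must agree exactly with the posterior mean of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ from Part 1. Checking the two formulas against each other on random tensors catches sign and coefficient bugs cheaply:

```python
import torch

torch.manual_seed(0)
T, t = 1000, 123
betas = torch.linspace(1e-4, 2e-2, T)
alphas = 1 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
ab, ab_prev = alphas_bar[t], alphas_bar[t - 1]

x0 = torch.randn(4, 1, 32, 32)
eps = torch.randn_like(x0)
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps

# Epsilon-parameterized mean (what `sample` computes).
mean_eps = (x_t - betas[t] / (1 - ab).sqrt() * eps) / alphas[t].sqrt()

# Posterior mean of q(x_{t-1} | x_t, x_0) from Part 1.
mean_post = (ab_prev.sqrt() * betas[t] / (1 - ab)) * x0 \
          + (alphas[t].sqrt() * (1 - ab_prev) / (1 - ab)) * x_t

print(torch.allclose(mean_eps, mean_post, atol=1e-5))  # True
```

If this ever prints False, the most likely culprits are a swapped sign on the noise term or a mixed-up $\bar\alpha_t$ versus $\bar\alpha_{t-1}$.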

To sample from the EMA weights:

python
@torch.no_grad()
def sample_from_ema(model, ema, sched, **kwargs):
    # swap in EMA weights, sample, swap back
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = sample(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

Why Sampling Is Slow (and How to Make It Faster)

This loop runs the network $T = 1000$ times per batch of samples. That is the fundamental cost of DDPM sampling and is the largest practical drawback of vanilla diffusion. Two families of improvements:

  • DDIM (Song et al., 2021). A deterministic sampler that uses the same trained model but can skip timesteps, often taking 25 to 50 network evaluations instead of 1000 with minimal quality loss. This is exactly the sampler we derived in Part 2.
  • Latent diffusion (Rombach et al., 2022). Do the diffusion process in a lower-dimensional latent space learned by a VAE, rather than on raw pixels. Each step is cheaper and you need fewer of them.
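As a sketch of the first option (the derivation is in Part 2), here is a minimal deterministic DDIM loop with $\eta = 0$ that reuses the same model and Schedule from this post, jumping across a strided subset of timesteps; treat it as an illustration, not a tuned implementation:

```python
import torch

@torch.no_grad()
def ddim_sample(model, sched, n=16, steps=50, img_size=32, channels=1, device="cpu"):
    """Deterministic DDIM (eta = 0) over a strided subset of timesteps."""
    ts = torch.linspace(sched.T - 1, 0, steps).long().tolist()
    x = torch.randn(n, channels, img_size, img_size, device=device)

    for i, t in enumerate(ts):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)

        ab_t = sched.alphas_bar[t]
        # Predict x0, then jump straight to the previous kept timestep.
        x0_pred = (x - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        if i + 1 < len(ts):
            ab_prev = sched.alphas_bar[ts[i + 1]]
            x = ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps
        else:
            x = x0_pred
    return x.clamp(-1, 1)
```

Because no fresh noise is injected, the same starting $\mathbf{x}_T$ always maps to the same image, which also makes DDIM handy for interpolation experiments.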

For MNIST at $32 \times 32$, the naive sampler is fine.

Putting It All Together

python
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
    imgs = sample_from_ema(model, ema, sched, n=64)
 
    # imgs is in [-1, 1]; rescale to [0, 1] for display
    imgs = (imgs + 1) / 2
 
    from torchvision.utils import save_image
    save_image(imgs, "samples.png", nrow=8)

After 30 epochs you should get something like clean, legible MNIST digits. If the samples look like noise or mostly uniform gray, the usual suspects are:

  • You forgot the [-1, 1] normalization; the model trains against $\mathcal{N}(0, 1)$ noise, so unscaled $[0, 1]$ pixels will produce visually dim samples.
  • You are sampling from the live model instead of EMA (quality will be visibly worse, especially early in training).
  • You are gathering by $t$ but forgot the .view(-1, 1, 1, 1) broadcast shape, so the schedule coefficients broadcast wrong.
  • You swapped the sign somewhere: the formula $\mathbf{x}_t - \text{coef} \cdot \boldsymbol{\epsilon}_\theta$ is correct (we are removing predicted noise), not $\mathbf{x}_t + \text{coef} \cdot \boldsymbol{\epsilon}_\theta$.

How Each Piece Maps to Part 1

To make the correspondence explicit:

| Math in Part 1 | Code in this post |
| --- | --- |
| $\beta_t, \alpha_t, \bar\alpha_t$ | Schedule / make_schedule |
| $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$ | posterior_variance |
| $\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}$ | q_sample |
| $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ | UNet.forward |
| $\mathcal{L}_{\mathrm{simple}} = \lVert\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\rVert^2$ | F.mse_loss(noise_pred, noise) in train_step |
| $\boldsymbol{\mu}_\theta = \tfrac{1}{\sqrt{\alpha_t}}\big(\mathbf{x}_t - \tfrac{\beta_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol{\epsilon}_\theta\big)$ | mean = sqrt_recip_alphas[t] * (x - coef * eps) in sample |
| $\mathbf{x}_{t-1} = \boldsymbol{\mu}_\theta + \sigma_t \mathbf{z}$ | x = mean + sqrt(posterior_variance[t]) * noise |

If any of these rows confuses you, go back to the corresponding section of Part 1; the translation is line-for-line.

What's Next

With this scaffolding, natural next steps are:

  • Better schedules. The cosine schedule of Nichol and Dhariwal (2021) keeps more signal at high $t$ and trains faster.
  • Classifier-free guidance. The one trick that made text-to-image diffusion actually work. A tiny change to training (drop the conditioning with 10% probability) and sampling (combine conditional and unconditional predictions).
  • Faster samplers. DDIM, DPM-Solver, and Heun-style second-order integrators get you from 1000 steps down to 10 to 50 without retraining.
  • Scaling up. CIFAR-10 or CelebA $64 \times 64$: increase UNet channels, add self-attention at $16 \times 16$ resolution, train longer. The code changes are modest; the training budget is not.
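For the first bullet, the cosine schedule is essentially a drop-in replacement for the torch.linspace call inside make_schedule. A sketch of Nichol and Dhariwal's formula, with their offset $s = 0.008$ and the standard clip on $\beta$ to avoid a singular final step:

```python
import math
import torch

def cosine_betas(T=1000, s=0.008):
    """Betas derived from alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2)."""
    steps = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos(((steps / T + s) / (1 + s)) * math.pi / 2) ** 2
    alphas_bar = f / f[0]                      # normalize so alpha_bar(0) = 1
    betas = 1 - alphas_bar[1:] / alphas_bar[:-1]
    return betas.clamp(max=0.999).float()      # clip, as in the paper

betas = cosine_betas()
print(betas.shape)  # torch.Size([1000])
```

Compared to the linear schedule, the early betas here are much smaller, so less signal is wasted at low $t$, which is exactly the property the paper argues speeds up training.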

The fundamentals do not change. Every diffusion model you will encounter, from Stable Diffusion to video diffusion to molecular generation, is a variation on the two posts in this series: an ELBO that reduces to a noise-prediction MSE, and a UNet (or Transformer) trained to minimize it.


Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
