Diffusion Deep Dive Part 3: Coding a DDPM from Scratch

Suchinthaka W. · 22 min read
Tags: diffusion, generative-models, deep-learning, ddpm, pytorch, machine-learning

In Part 1 we derived the DDPM training objective and the DDPM sampler. In Part 2 we extended that sampler to DDIM so the same network can generate in $\sim 50$ steps instead of $1000$. This post turns the math into running code.

$$\mathcal{L}_{\mathrm{simple}}(\theta) = \mathbb{E}_{t,\,\mathbf{x}_0,\,\boldsymbol{\epsilon}}\!\left[ \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\!\big(\sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon},\; t\big) \big\|^2 \right]$$

We build a DDPM that trains on MNIST in about an hour on a single GPU (or a few hours on CPU, if you are patient). The point is not state-of-the-art image quality; it is to have a minimal, readable reference where every line maps to an equation from Part 1 and every sampler change maps to Part 2.

What we will build:

  1. The noise schedule $\{\beta_t, \alpha_t, \bar\alpha_t\}$.
  2. The closed-form forward process $q(\mathbf{x}_t \mid \mathbf{x}_0)$.
  3. A small UNet $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ with sinusoidal time embeddings.
  4. The training loop (three lines of math, about ten of code).
  5. The iterative DDPM sampler (Algorithm 2 from Ho et al.).
  6. Practical stability tricks (EMA, gradient clipping, amp).

Reference: Ho, Jain, Abbeel, "Denoising Diffusion Probabilistic Models" (NeurIPS 2020). All equation numbers in this post refer to Part 1 of this series.

Setup

We will use PyTorch. Nothing else is required.

python
import math
from dataclasses import dataclass
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
device = "cuda" if torch.cuda.is_available() else "cpu"

For the dataset, MNIST is the right starting point: it is small, the images are $28 \times 28$, and you can see whether the model is working within a few epochs. We resize to $32 \times 32$ (a power of two, which makes the UNet's down and up sampling clean) and scale pixels to $[-1, 1]$ so the data is centered and on roughly the same scale as the unit-variance noise we will add.

python
def get_loader(batch_size=128):
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),             # [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)
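
The pixel scaling is easy to get backwards, so here is a minimal plain-Python sketch of the arithmetic that ToTensor followed by Normalize((0.5,), (0.5,)) applies to each pixel, together with its inverse. The helper names to_model_range and to_image_range are made up for this illustration and are not part of torchvision.

```python
def to_model_range(p: float) -> float:
    """Map a pixel value in [0, 1] to [-1, 1]: (p - mean) / std with mean = std = 0.5."""
    return (p - 0.5) / 0.5


def to_image_range(x: float) -> float:
    """Inverse map from [-1, 1] back to [0, 1], e.g. before writing a PNG."""
    return (x + 1.0) / 2.0


for p in (0.0, 0.25, 1.0):
    x = to_model_range(p)
    print(f"pixel {p:.2f} -> model input {x:+.2f} -> back to {to_image_range(x):.2f}")
```

The inverse map is the same (x + 1) / 2 expression that is applied to the samples before they are saved at the end of the post.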

The Noise Schedule

Recall the definitions from Part 1: $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$. The noise schedule $\{\beta_t\}$ controls how aggressively we destroy the signal at each step. Ho et al. use a linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$. It is simple and works well enough for $32 \times 32$ images.

We precompute every quantity we will ever need, for every timestep, once at init. This avoids repeated cumulative products inside the training loop.

python
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor          # (T,)
    alphas: torch.Tensor         # (T,)
    alphas_bar: torch.Tensor     # (T,)
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor  # tilde-beta_t
 
 
def make_schedule(T=1000, beta_start=1e-4, beta_end=2e-2, device="cpu"):
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
 
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
 
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )

Two notes worth calling out:

  • alphas_bar_prev is $\bar\alpha_{t-1}$. We set $\bar\alpha_0 = 1$ by convention (nothing has been corrupted at $t = 0$), which is the pad(..., value=1.0) trick.
  • posterior_variance is $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$ from Part 1. This is what we use as $\sigma_t^2$ at sampling time.
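
To get a feel for the numbers, the sketch below rebuilds the same linear schedule with plain Python lists (no PyTorch) and prints a few entries. The function name linear_schedule is an assumption made for this example; it is meant to mirror make_schedule, not replace it.

```python
import math


def linear_schedule(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Return (betas, alphas_bar, posterior_variance) as plain Python lists."""
    betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]
    alphas_bar, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta          # cumulative product of alpha_s
        alphas_bar.append(running)
    # alphas_bar_prev is alphas_bar shifted right by one, with alpha_bar_0 = 1.
    alphas_bar_prev = [1.0] + alphas_bar[:-1]
    posterior_variance = [
        b * (1.0 - abp) / (1.0 - ab)
        for b, ab, abp in zip(betas, alphas_bar, alphas_bar_prev)
    ]
    return betas, alphas_bar, posterior_variance


betas, alphas_bar, post_var = linear_schedule()
for idx in (0, 249, 499, 999):       # 0-based index idx corresponds to t = idx + 1
    print(f"t={idx + 1:4d}  alpha_bar={alphas_bar[idx]:.6f}  "
          f"sqrt(alpha_bar)={math.sqrt(alphas_bar[idx]):.4f}  "
          f"tilde_beta={post_var[idx]:.6f}")
```

With these defaults $\bar\alpha_T$ ends up on the order of $10^{-5}$, so $\mathbf{x}_T$ carries essentially no signal. The first posterior variance is exactly zero because $\bar\alpha_0 = 1$.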

The Forward Process (Closed Form)

From Part 1, equation (5), we can sample $\mathbf{x}_t$ from $\mathbf{x}_0$ in one shot:

$$\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$

No loop over $t$. This is the single most important efficiency property of diffusion training.

python
def gather(buf, t):
    """buf: (T,), t: (B,) -> (B, 1, 1, 1) for broadcasting over (B, C, H, W)."""
    return buf.gather(0, t).view(-1, 1, 1, 1)
 
 
def q_sample(x0, t, noise, sched):
    """Sample x_t from q(x_t | x_0) in closed form."""
    sqrt_ab = gather(sched.sqrt_alphas_bar, t)
    sqrt_one_minus_ab = gather(sched.sqrt_one_minus_alphas_bar, t)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

gather looks trivial but it is load-bearing: each sample in the batch has its own $t$, so we need per-sample coefficients broadcast over the spatial dimensions.
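
As a sanity check on the closed form itself, the following plain-Python sketch walks the forward chain one step at a time, assuming the usual DDPM one-step kernel $\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}_t$. It tracks the coefficient on $\mathbf{x}_0$ and the accumulated noise variance, then compares both with $\sqrt{\bar\alpha_t}$ and $1 - \bar\alpha_t$.

```python
import math

T, beta_start, beta_end = 1000, 1e-4, 2e-2
betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]

mean_coef, var = 1.0, 0.0     # x_0 itself: coefficient 1 on x_0, no noise yet
alpha_bar = 1.0
max_err = 0.0
for beta in betas:
    alpha = 1.0 - beta
    # One forward step: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t
    mean_coef *= math.sqrt(alpha)
    var = alpha * var + beta
    alpha_bar *= alpha
    # Closed form predicts coefficient sqrt(alpha_bar_t) and variance 1 - alpha_bar_t.
    max_err = max(max_err,
                  abs(mean_coef - math.sqrt(alpha_bar)),
                  abs(var - (1.0 - alpha_bar)))

print(f"largest mismatch over {T} steps: {max_err:.2e}")
```

The mismatch should sit at floating-point noise level, which is the numerical counterpart of the statement that q_sample needs no loop.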

The Network: A Small UNet with Time Embeddings

The network $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ takes a noisy image and a timestep and outputs a prediction of the noise, with the same shape as the input. A UNet is the canonical choice: encoder-decoder with skip connections, which preserves spatial detail that would otherwise be lost during downsampling.

Sinusoidal Time Embeddings

The timestep $t$ is an integer in $\{1, \ldots, T\}$ (stored as the 0-based index $0, \ldots, T-1$ in the code), but the network needs a continuous, information-rich representation. We use the same sinusoidal embedding as the Transformer (and for the same reason: it encodes position on many frequency scales at once). This embedding is then projected through an MLP and injected additively into every residual block.

python
class TimeEmbedding(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )
 
    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)   # (B, dim)
        return self.mlp(emb)                                # (B, hidden)
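
To see what the raw features look like before the MLP, here is a small plain-Python sketch of the same sin/cos computation for a single timestep. The function name sinusoidal_embedding is hypothetical; it follows the formula in TimeEmbedding.forward but uses the math module instead of tensors.

```python
import math


def sinusoidal_embedding(t: int, dim: int) -> list[float]:
    """Plain-Python version of the sin/cos features computed in TimeEmbedding.forward."""
    assert dim % 2 == 0
    half = dim // 2
    # Geometric ladder of frequencies from 1 down towards 1/10000.
    freqs = [math.exp(-math.log(10000) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


for t in (0, 1, 500, 999):
    emb = sinusoidal_embedding(t, dim=8)
    print(t, [round(v, 3) for v in emb])
```

Every entry stays in $[-1, 1]$ regardless of $t$, and the fast frequencies separate neighbouring timesteps while the slow ones separate early from late ones.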

Residual Block with Time Conditioning

Each residual block does: group-norm, SiLU, conv; inject time embedding (one scalar per channel, added as bias); group-norm, SiLU, conv; plus a skip connection. This is a simplified version of the block in Ho et al.

python
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
 
    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

The UNet

Three resolutions: $32 \to 16 \to 8$. Each level has one residual block going down, one at the bottleneck, and one going up, with skip connections concatenating the encoder features into the decoder. Keep it small. This is enough for MNIST and will train in minutes on a modern GPU.

python
class UNet(nn.Module):
    def __init__(self, in_ch=1, base=64, t_dim=128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
 
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
 
        # Encoder
        self.down1 = ResBlock(base,     base,     t_out)
        self.pool1 = nn.Conv2d(base,     base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
 
        # Bottleneck
        self.mid   = ResBlock(base * 4, base * 4, t_out)
 
        # Decoder (skip-connect by concatenation, so input channels double)
        self.up2   = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2  = ResBlock(base * 4, base * 2, t_out)
        self.up1   = nn.ConvTranspose2d(base * 2, base,     4, stride=2, padding=1)
        self.dec1  = ResBlock(base * 2, base,     t_out)
 
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)
 
    def forward(self, x, t):
        t_emb = self.time(t)
 
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)        # (B, base,     32, 32)
        h2 = self.down2(self.pool1(h1), t_emb)  # (B, 2*base, 16, 16)
        hb = self.mid(self.pool2(h2), t_emb)    # (B, 4*base,  8,  8)
 
        u2 = self.up2(hb)                          # (B, 2*base, 16, 16)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)                          # (B, base,   32, 32)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
 
        return self.out_conv(F.silu(self.out_norm(u1)))

A real implementation (Ho et al., OpenAI guided-diffusion, Stable Diffusion) adds self-attention at low resolutions, multiple residual blocks per level, and more channels. For the purposes of this post, the architecture above is intentionally stripped down.
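
The shape comments in UNet.forward can be checked by hand with the standard output-size formulas for strided and transposed convolutions (dilation 1, no output_padding). The sketch below does that arithmetic in plain Python; conv_out and conv_transpose_out are helper names invented for this example.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.ConvTranspose2d along one dimension (dilation 1, no output_padding)."""
    return (n - 1) * stride - 2 * padding + kernel


size = 32
down1 = conv_out(size, kernel=3, stride=2, padding=1)              # pool1
down2 = conv_out(down1, kernel=3, stride=2, padding=1)             # pool2
up2 = conv_transpose_out(down2, kernel=4, stride=2, padding=1)     # up2
up1 = conv_transpose_out(up2, kernel=4, stride=2, padding=1)       # up1
print(size, "->", down1, "->", down2, "->", up2, "->", up1)

# Channel bookkeeping for the concatenating skips, with base = 64:
base = 64
dec2_in = 2 * base + 2 * base   # up2 output + h2
dec1_in = base + base           # up1 output + h1
print("dec2 input channels:", dec2_in, " dec1 input channels:", dec1_in)
```

The decoder sizes have to match the encoder sizes exactly, otherwise torch.cat on the skip connection fails. That is why dec2 is declared with base * 4 input channels and dec1 with base * 2.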

Training

The training step is the one you came here for. Sample $(\mathbf{x}_0, t, \boldsymbol{\epsilon})$, form $\mathbf{x}_t$ in closed form, predict the noise, compute MSE. Four lines.

python
def train_step(model, x0, sched, optimizer):
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
 
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
 
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

Gradient clipping at norm $1.0$ is a cheap safety net: occasional outlier batches (very small $\bar\alpha_t$, where $\mathbf{x}_t$ is almost pure noise and the loss surface is flat) can produce huge gradients. Without clipping, a single bad step can destabilize training.
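
For readers who have not looked inside norm clipping, here is a rough plain-Python sketch of the idea: compute one global L2 norm over all gradients and, if it exceeds the limit, rescale everything by the same factor. The function clip_by_global_norm is an illustrative stand-in operating on nested lists. The real clip_grad_norm_ works in place on parameter gradients, and this sketch only assumes it behaves along these lines.

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float):
    """Rescale all gradients by one shared factor so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The small epsilon avoids division by zero; the factor never exceeds 1.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [[g * scale for g in grad] for grad in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print("norm before:", norm, "after:", math.sqrt(sum(g * g for grad in clipped for g in grad)))

small, _ = clip_by_global_norm([[0.03, 0.04]], max_norm=1.0)
print("small gradients pass through unchanged:", small)
```

Because every gradient is scaled by the same factor, the update direction is preserved and only its length is capped.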

EMA of Weights

One practical detail that matters a lot for sample quality: keep an exponential moving average (EMA) of the model weights, and sample from the EMA copy instead of the live model. Diffusion losses are noisy (different $t$ every step), so the live weights oscillate; the EMA smooths this out. Typical decay is $0.9999$.

python
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
 
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)
 
    def copy_to(self, model):
        model.load_state_dict(self.shadow)
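
One thing worth checking for a short run is how much of the random initialization is still baked into the EMA. Since the shadow starts as a copy of the initial weights and each update multiplies it by the decay, the initial weights keep a share of decay ** steps. The sketch below works that out for the setup in this post, assuming the 60,000-image MNIST training set with batch_size=128 and drop_last=True. The helper name is made up for the example.

```python
def init_fraction_remaining(decay: float, steps: int) -> float:
    """Weight that the initial (random) parameters still carry in the EMA after `steps` updates."""
    return decay ** steps


steps_per_epoch = 60_000 // 128          # MNIST train set, batch_size=128, drop_last=True
total_steps = steps_per_epoch * 30       # 30 epochs

for decay in (0.9999, 0.999):
    frac = init_fraction_remaining(decay, total_steps)
    print(f"decay={decay}: {total_steps} steps, initial weights still contribute {frac:.2%}")
```

With decay 0.9999 roughly a quarter of the EMA is still the untrained initialization after 30 epochs. If EMA samples look worse than expected on a short run, trying a smaller decay such as 0.999 or a warm-up on the decay is a reasonable experiment. That is a suggestion on top of the code in this post, not something it implements.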

The Full Training Loop

python
def train(epochs=30, batch_size=128, lr=2e-4, T=1000):
    loader = get_loader(batch_size)
    sched  = make_schedule(T=T, device=device)
 
    model = UNet(in_ch=1, base=64).to(device)
    ema   = EMA(model, decay=0.9999)
    opt   = torch.optim.AdamW(model.parameters(), lr=lr)
 
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
 
    return model, ema, sched

On an RTX 4090, 30 epochs of MNIST takes roughly 15 to 20 minutes and produces legible digits. You should see the training loss drop quickly from $\sim 1.0$ to $\sim 0.04$ within the first epoch and then crawl down slowly. Do not expect it to go to zero: the loss is an expectation over all noise levels, and the high-$t$ regime is irreducibly hard (the network cannot predict the noise when the signal is basically gone).
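
The starting value near $1.0$ has a simple explanation: the regression target is unit-variance Gaussian noise, so a predictor that outputs zeros everywhere pays an MSE equal to that variance. The sketch below estimates it with the standard library. A freshly initialized UNet is not exactly a zero predictor, so treat this as a rough baseline rather than a guarantee.

```python
import random

random.seed(0)
n = 100_000
eps = [random.gauss(0.0, 1.0) for _ in range(n)]

# A predictor that always outputs 0 pays the full variance of the noise.
mse_zero_predictor = sum(e * e for e in eps) / n
print(f"MSE of an all-zeros prediction: {mse_zero_predictor:.3f}")
```

A loss that stays near this baseline after a few hundred steps usually means the network is not learning anything from its input.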

Sampling

Sampling is where the iterative structure comes back. We cannot skip timesteps the way we can in training: to sample $\mathbf{x}_{t-1}$ we need $\mathbf{x}_t$. Algorithm 2 from DDPM:

Starting from $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, for $t = T, T-1, \ldots, 1$:

$$\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\!\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right) + \sigma_t \mathbf{z},$$

where $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ for $t > 1$ and $\mathbf{z} = \mathbf{0}$ at the last step. We use $\sigma_t^2 = \tilde\beta_t$ (the posterior variance we precomputed).

The formula is exactly the $\boldsymbol{\mu}_\theta$ expression from Part 1, Step 4 of the noise-prediction parameterization, plus a Gaussian noise injection.

python
@torch.no_grad()
def sample(model, sched, n=16, img_size=32, channels=1, device=device):
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
 
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
 
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
 
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean  # no noise at the last step
 
    model.train()
    return x.clamp(-1, 1)
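
The mean used in sample can be cross-checked against the other common way of writing it: the Gaussian posterior mean of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ from Ho et al., with $\mathbf{x}_0$ replaced by the estimate implied by the predicted noise. The plain-Python sketch below compares the two forms on scalar stand-ins for one pixel. The function names and the test values 0.7 and -1.3 are arbitrary choices for the illustration.

```python
import math

T, beta_start, beta_end = 1000, 1e-4, 2e-2
betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]
alphas_bar, running = [], 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_bar.append(running)


def mean_from_eps(x_t: float, eps: float, i: int) -> float:
    """Mean as computed in `sample`: (x_t - beta / sqrt(1 - alpha_bar) * eps) / sqrt(alpha)."""
    alpha = 1.0 - betas[i]
    coef = betas[i] / math.sqrt(1.0 - alphas_bar[i])
    return (x_t - coef * eps) / math.sqrt(alpha)


def mean_from_x0(x_t: float, eps: float, i: int) -> float:
    """Posterior mean of q(x_{t-1} | x_t, x_0), with x_0 replaced by its estimate from eps."""
    alpha, ab = 1.0 - betas[i], alphas_bar[i]
    ab_prev = alphas_bar[i - 1] if i > 0 else 1.0
    x0_hat = (x_t - math.sqrt(1.0 - ab) * eps) / math.sqrt(ab)
    return (math.sqrt(ab_prev) * betas[i] / (1.0 - ab)) * x0_hat \
        + (math.sqrt(alpha) * (1.0 - ab_prev) / (1.0 - ab)) * x_t


x_t, eps = 0.7, -1.3        # arbitrary scalar stand-ins for one pixel
max_gap = max(abs(mean_from_eps(x_t, eps, i) - mean_from_x0(x_t, eps, i)) for i in range(T))
print(f"largest gap between the two forms: {max_gap:.2e}")
```

If the two forms disagree by more than rounding error in your own code, the usual culprit is an off-by-one between $\bar\alpha_t$ and $\bar\alpha_{t-1}$.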

To sample from the EMA weights:

python
@torch.no_grad()
def sample_from_ema(model, ema, sched, **kwargs):
    # swap in EMA weights, sample, swap back
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = sample(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

Why This Sampler Is Slow

This loop runs the network $T = 1000$ times per batch of samples. That is the fundamental cost of DDPM sampling and the largest practical drawback of vanilla diffusion. Two families of fixes:

  • DDIM (Song et al., 2021). Same trained network, but a different reverse process that can skip timesteps: $25$ to $50$ network evaluations with near-identical quality. Derived in Part 2. We implement it below.
  • Latent diffusion (Rombach et al., 2022). Do diffusion in a lower-dimensional latent space learned by a VAE. Each step is cheaper and you need fewer of them. Out of scope here.

Sampling Faster: DDIM

The derivation is in Part 2; the code is shorter than the math. Two additions to the file we already have:

  1. A helper that picks the sub-sampled timestep schedule $\tau = (\tau_0, \ldots, \tau_S)$ from Part 2 equation $(10)$.
  2. The DDIM reverse step from Part 2 equation $(11)$, which plugs $\hat{\mathbf{x}}_0$ from equation $(5)$ and the $\eta$-parameterized $\sigma_t$ from equation $(7)$ into a single loop.

python
def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    """Ascending sub-sequence of timestep indices 0..T-1 of length ~num_steps+1.
    Includes T-1 as the last index so sampling starts from pure noise."""
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts
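
It helps to look at what this helper actually returns for a few settings. The snippet below repeats the function so that it runs on its own and prints the grid length, the end points, and the distinct gaps between consecutive timesteps.

```python
def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    # Copied from the helper above so this snippet runs on its own.
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts


for num_steps in (10, 50, 1000):
    ts = make_ddim_timesteps(1000, num_steps)
    gaps = sorted({b - a for a, b in zip(ts, ts[1:])})
    print(f"num_steps={num_steps:4d}: {len(ts)} timesteps, "
          f"first={ts[0]}, last={ts[-1]}, distinct gaps={gaps}")
```

Two details follow from the construction: appending T - 1 means the grid can be slightly longer than num_steps (51 entries for num_steps=50), and the final gap can differ from the others. Neither affects correctness, but it explains small discrepancies when counting network calls.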

The DDIM sampler itself:

python
@torch.no_grad()
def sample_ddim(model, sched, n=16, img_size=32, channels=1,
                num_steps=50, eta=0.0, device=device):
    """DDIM sampler (Song, Meng, Ermon 2021).
 
    num_steps: approximate number of network evaluations (~50 is usually enough).
    eta=0.0:   fully deterministic sampler (probability-flow ODE).
    eta=1.0:   recovers DDPM on the sub-sampled grid.
    """
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    ts = make_ddim_timesteps(sched.T, num_steps)  # ascending
 
    for i in reversed(range(len(ts))):
        t      = ts[i]
        t_prev = ts[i - 1] if i > 0 else -1
 
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
 
        ab_t    = sched.alphas_bar[t]
        ab_prev = sched.alphas_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
 
        # (a) predict the clean image  [Part 2 eq. (5)]
        x0_pred = (x - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)
        x0_pred = x0_pred.clamp(-1, 1)                          # stability, optional
 
        # (b) sigma_t from the eta knob  [Part 2 eq. (7)]
        sigma_sq = (eta ** 2) * ((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - ab_t / ab_prev)
        sigma    = torch.sqrt(sigma_sq.clamp(min=0.0))
 
        # (c) deterministic drift along the predicted noise  [Part 2 eq. (6)]
        dir_coef = torch.sqrt((1.0 - ab_prev - sigma_sq).clamp(min=0.0))
 
        # (d) stochastic noise injection; zero on the final step
        z = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
 
        x = torch.sqrt(ab_prev) * x0_pred + dir_coef * eps + sigma * z
 
    model.train()
    return x.clamp(-1, 1)

Four labeled blocks, each corresponding to one term of Part 2's boxed update equation $(6)$. That is the entire new sampler.
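
A useful property to keep in mind is that $\eta$ only moves variance between the deterministic direction term and the fresh noise: the squared coefficients from blocks (b) and (c) always add up to $1 - \bar\alpha_{\text{prev}}$. The scalar sketch below checks this for a few values of $\eta$. The function name ddim_coeffs and the sample values 0.30 and 0.45 are illustrative assumptions, not taken from a trained schedule.

```python
import math


def ddim_coeffs(ab_t: float, ab_prev: float, eta: float):
    """Return (sigma, dir_coef) for one DDIM step, following the formulas in sample_ddim."""
    sigma_sq = eta ** 2 * ((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - ab_t / ab_prev)
    dir_coef = math.sqrt(max(0.0, 1.0 - ab_prev - sigma_sq))
    return math.sqrt(max(0.0, sigma_sq)), dir_coef


ab_t, ab_prev = 0.30, 0.45      # illustrative values with ab_t < ab_prev < 1
for eta in (0.0, 0.5, 1.0):
    sigma, dir_coef = ddim_coeffs(ab_t, ab_prev, eta)
    total = sigma ** 2 + dir_coef ** 2
    print(f"eta={eta}: sigma={sigma:.4f}, dir_coef={dir_coef:.4f}, "
          f"sigma^2 + dir_coef^2={total:.4f} (1 - ab_prev={1 - ab_prev:.4f})")
```

With eta=0 the noise coefficient is exactly zero, which is what makes the sampler deterministic given the initial $\mathbf{x}_T$.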

Sanity Check: $\eta = 1$ Must Recover DDPM

If we take num_steps = sched.T (so $\tau$ is the full grid $0, 1, \ldots, T-1$) and set eta = 1.0, DDIM and DDPM should produce statistically identical samples (up to the optional clamp on x0_pred, which the DDPM sampler does not apply). Worth asserting in a unit test:

python
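# With eta=1 and the full grid, DDIM reduces to DDPM (Part 2 §4.1).
sched = make_schedule(T=1000)
ab_t = sched.alphas_bar                           # alpha_bar_t for every t
ab_prev = F.pad(ab_t[:-1], (1, 0), value=1.0)     # alpha_bar_{t-1}, with alpha_bar_0 = 1
sigma_sq_eta1 = ((1 - ab_prev) / (1 - ab_t)) * (1 - ab_t / ab_prev)
# This should equal sched.posterior_variance (= tilde_beta_t from Part 1) at every t.
assert torch.allclose(sigma_sq_eta1, sched.posterior_variance, atol=1e-6)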

If that equality does not hold to within floating-point slop, something is wrong with your schedule.

EMA Wrapper

To sample from the EMA copy with either sampler:

python
@torch.no_grad()
def sample_with_ema(model, ema, sched, fn=sample, **kwargs):
    """fn is `sample` (DDPM) or `sample_ddim`."""
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = fn(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

DDPM vs DDIM: What to Expect

| Sampler | num_steps | Wall-clock (RTX 4090, n=64) | Visual quality |
| --- | --- | --- | --- |
| `sample` (DDPM) | $1000$ | $\sim 8$ s | baseline |
| `sample_ddim(eta=1.0)` | $50$ | $\sim 0.4$ s | near-baseline |
| `sample_ddim(eta=0.0)` | $50$ | $\sim 0.4$ s | near-baseline, deterministic |
| `sample_ddim(eta=0.0)` | $20$ | $\sim 0.15$ s | mild loss of detail |
| `sample_ddim(eta=0.0)` | $10$ | $\sim 0.08$ s | visible artifacts on MNIST |

The $20\times$ speed-up comes from doing $20\times$ fewer network forward passes; the UNet cost per step is unchanged.
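
The timings in the table are consistent with a constant cost per network call. As a back-of-the-envelope sketch, take the DDPM row (about 8 s for 1000 steps) as the only input and predict the other rows from it. The numbers are the approximate figures quoted above, not new measurements.

```python
ddpm_steps, ddpm_seconds = 1000, 8.0          # first row of the table above
seconds_per_step = ddpm_seconds / ddpm_steps  # about 8 ms per UNet call at n=64

for steps in (50, 20, 10):
    estimate = steps * seconds_per_step
    print(f"{steps:3d} steps: ~{estimate:.2f} s, {ddpm_steps / steps:.0f}x fewer network calls")
```

Since the estimate is linear in the step count, any remaining gap between prediction and measurement comes from fixed overheads such as kernel launches and the final clamp.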

Putting It All Together

python
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
 
    # Slow, stochastic DDPM sampling (T=1000 network calls)
    ddpm_imgs = sample_with_ema(model, ema, sched, fn=sample,      n=64)
 
    # Fast, deterministic DDIM sampling (50 network calls)
    ddim_imgs = sample_with_ema(model, ema, sched, fn=sample_ddim, n=64,
                                num_steps=50, eta=0.0)
 
    from torchvision.utils import save_image
    save_image((ddpm_imgs + 1) / 2, "samples_ddpm.png", nrow=8)
    save_image((ddim_imgs + 1) / 2, "samples_ddim.png", nrow=8)

After 30 epochs you should get clean, legible MNIST digits from both samplers. If the samples look like noise or mostly uniform gray, the usual suspects are:

  • You forgot the [-1, 1] normalization; the model trains against $\mathcal{N}(0, 1)$ noise, so unscaled $[0, 1]$ pixels will produce visually dim samples.
  • You are sampling from the live model instead of EMA (quality will be visibly worse, especially early in training).
  • You are gathering by $t$ but forgot the .view(-1, 1, 1, 1) broadcast shape, so the schedule coefficients broadcast wrong.
  • You swapped the sign somewhere: the formula $\mathbf{x}_t - \text{coef} \cdot \boldsymbol{\epsilon}_\theta$ is correct (we are removing predicted noise), not $\mathbf{x}_t + \text{coef} \cdot \boldsymbol{\epsilon}_\theta$.
  • For DDIM: a NaN from one of the torch.sqrt calls usually means $\eta > 1$ or a schedule bug. With a valid schedule sigma_sq itself stays non-negative, so the term that goes negative for $\eta > 1$ is the dir_coef argument 1 - ab_prev - sigma_sq; clamp both to $\geq 0$.

How Each Piece Maps to Parts 1 and 2

| Math | Code |
| --- | --- |
| $\beta_t, \alpha_t, \bar\alpha_t$ (Part 1) | `Schedule` / `make_schedule` |
| $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$ (Part 1 (27)) | `posterior_variance` |
| $\mathbf{x}_t = \sqrt{\bar\alpha_t}\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\boldsymbol{\epsilon}$ (Part 1 (41)) | `q_sample` |
| $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ | `UNet.forward` |
| $\mathcal{L}_{\mathrm{simple}} = \lVert\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\rVert^2$ (Part 1 (48)) | `F.mse_loss(noise_pred, noise)` in `train_step` |
| $\boldsymbol{\mu}_\theta = \tfrac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t - \tfrac{\beta_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol{\epsilon}_\theta)$ (Part 1 (45)) | `mean = sqrt_recip_alphas[t] * (x - coef * eps)` in `sample` |
| $\mathbf{x}_{t-1} = \boldsymbol{\mu}_\theta + \sigma_t\mathbf{z}$ (Part 1 (51)) | `x = mean + sqrt(posterior_variance[t]) * noise` |
| $\hat{\mathbf{x}}_0 = (\mathbf{x}_t - \sqrt{1-\bar\alpha_t}\boldsymbol{\epsilon}_\theta)/\sqrt{\bar\alpha_t}$ (Part 2 (5)) | `x0_pred = (x - sqrt(1-ab_t) * eps) / sqrt(ab_t)` in `sample_ddim` |
| $\sigma_t(\eta)$ (Part 2 (7)) | `sigma_sq = eta**2 * ((1-ab_prev)/(1-ab_t)) * (1 - ab_t/ab_prev)` |
| DDIM update (Part 2 (6)) | `x = sqrt(ab_prev)*x0_pred + dir_coef*eps + sigma*z` |
| Sub-sequence $\tau$ (Part 2 (10)) | `make_ddim_timesteps` |

If any of these rows confuses you, go back to the corresponding section; the translation is line-for-line.

Full Source: A Single Runnable File

Everything above condensed into one file. Save as ddpm_ddim.py, run with python ddpm_ddim.py, and it will train on MNIST and write samples_ddpm.png and samples_ddim.png.

python
"""
ddpm_ddim.py — Minimal DDPM training + DDPM/DDIM sampling, from scratch.
 
Companion code for:
  Part 1: /blog/diffusion-series-1-math-of-diffusion
  Part 2: /blog/diffusion-series-2-ddim
  Part 3: /blog/diffusion-series-3-coding-ddpm
 
Run:
  python ddpm_ddim.py
 
Dependencies: torch, torchvision.
"""
from __future__ import annotations
 
import math
from dataclasses import dataclass
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
 
device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
def get_loader(batch_size: int = 128) -> DataLoader:
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),                    # [0, 1]
        transforms.Normalize((0.5,), (0.5,)),     # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)
 
 
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor
    alphas: torch.Tensor
    alphas_bar: torch.Tensor
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor
 
 
def make_schedule(T: int = 1000, beta_start: float = 1e-4,
                  beta_end: float = 2e-2, device: str = "cpu") -> Schedule:
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
 
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
 
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )
 
 
def gather(buf: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return buf.gather(0, t).view(-1, 1, 1, 1)
 
 
def q_sample(x0: torch.Tensor, t: torch.Tensor,
             noise: torch.Tensor, sched: Schedule) -> torch.Tensor:
    """x_t from q(x_t | x_0) in closed form.  [Part 1 eq. (41)]"""
    return gather(sched.sqrt_alphas_bar, t) * x0 + \
           gather(sched.sqrt_one_minus_alphas_bar, t) * noise
 
 
# ---------------------------------------------------------------------------
# UNet
# ---------------------------------------------------------------------------
class TimeEmbedding(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )
 
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return self.mlp(emb)
 
 
class ResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, t_dim: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
 
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)
 
 
class UNet(nn.Module):
    def __init__(self, in_ch: int = 1, base: int = 64, t_dim: int = 128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
 
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
 
        self.down1 = ResBlock(base,     base,     t_out)
        self.pool1 = nn.Conv2d(base,     base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
 
        self.mid   = ResBlock(base * 4, base * 4, t_out)
 
        self.up2   = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2  = ResBlock(base * 4, base * 2, t_out)
        self.up1   = nn.ConvTranspose2d(base * 2, base,     4, stride=2, padding=1)
        self.dec1  = ResBlock(base * 2, base,     t_out)
 
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)
 
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.time(t)
 
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)
        h2 = self.down2(self.pool1(h1), t_emb)
        hb = self.mid(self.pool2(h2), t_emb)
 
        u2 = self.up2(hb)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
 
        return self.out_conv(F.silu(self.out_norm(u1)))
 
 
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train_step(model: UNet, x0: torch.Tensor, sched: Schedule,
               optimizer: torch.optim.Optimizer) -> float:
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
 
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
 
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()
 
 
class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
 
    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)
 
    def copy_to(self, model: nn.Module) -> None:
        model.load_state_dict(self.shadow)
 
 
def train(epochs: int = 30, batch_size: int = 128, lr: float = 2e-4,
          T: int = 1000) -> tuple[UNet, EMA, Schedule]:
    loader = get_loader(batch_size)
    sched  = make_schedule(T=T, device=device)
 
    model = UNet(in_ch=1, base=64).to(device)
    ema   = EMA(model, decay=0.9999)
    opt   = torch.optim.AdamW(model.parameters(), lr=lr)
 
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
 
    return model, ema, sched
 
 
# ---------------------------------------------------------------------------
# Samplers
# ---------------------------------------------------------------------------
@torch.no_grad()
def sample(model: UNet, sched: Schedule, n: int = 16,
           img_size: int = 32, channels: int = 1) -> torch.Tensor:
    """DDPM ancestral sampler (Ho et al. 2020, Algorithm 2).  [Part 1 eq. (51)]"""
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
 
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
 
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
 
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean
 
    model.train()
    return x.clamp(-1, 1)
 
 
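def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    """Ascending sub-sequence of timestep indices 0..T-1.

    T-1 is always kept, so the list can hold more than num_steps entries
    (51 for T=1000, num_steps=50)."""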
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts
 
 
@torch.no_grad()
def sample_ddim(model: UNet, sched: Schedule, n: int = 16,
                img_size: int = 32, channels: int = 1,
                num_steps: int = 50, eta: float = 0.0) -> torch.Tensor:
    """DDIM sampler (Song, Meng, Ermon 2021).  [Part 2 eq. (6), (7), (11)]
 
    eta=0.0 is deterministic (probability-flow ODE); eta=1.0 recovers DDPM
    on the sub-sampled grid."""
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    ts = make_ddim_timesteps(sched.T, num_steps)
 
    for i in reversed(range(len(ts))):
        t      = ts[i]
        t_prev = ts[i - 1] if i > 0 else -1
 
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
 
        ab_t    = sched.alphas_bar[t]
        ab_prev = sched.alphas_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
 
        x0_pred = ((x - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)).clamp(-1, 1)
 
        sigma_sq = (eta ** 2) * ((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - ab_t / ab_prev)
        sigma    = torch.sqrt(sigma_sq.clamp(min=0.0))
        dir_coef = torch.sqrt((1.0 - ab_prev - sigma_sq).clamp(min=0.0))
 
        z = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = torch.sqrt(ab_prev) * x0_pred + dir_coef * eps + sigma * z
 
    model.train()
    return x.clamp(-1, 1)
 
 
@torch.no_grad()
def sample_with_ema(model: UNet, ema: EMA, sched: Schedule,
                    fn=sample, **kwargs) -> torch.Tensor:
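    """Run sampler `fn` with the EMA weights, then restore the live weights."""
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    try:
        imgs = fn(model, sched, **kwargs)
    finally:
        model.load_state_dict(backup)  # restore even if sampling raises
    return imgs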
 
 
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
 
    ddpm_imgs = sample_with_ema(model, ema, sched, fn=sample,      n=64)
    ddim_imgs = sample_with_ema(model, ema, sched, fn=sample_ddim, n=64,
                                num_steps=50, eta=0.0)
 
    save_image((ddpm_imgs + 1) / 2, "samples_ddpm.png", nrow=8)
    save_image((ddim_imgs + 1) / 2, "samples_ddim.png", nrow=8)
    print("Wrote samples_ddpm.png and samples_ddim.png")

The file is about 310 lines. Three-quarters of it is the UNet; the DDPM side is roughly 40 lines of substantive code, and DDIM adds about 35.
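One detail of the listing that is easy to check in isolation is the DDIM timestep grid. The sketch below copies `make_ddim_timesteps` as it appears in the listing and prints the grid for the settings used in the entry point. The second function, `even_timesteps`, is a hypothetical alternative that is not part of the original file; it is included only to show one way to get exactly `num_steps` entries.

```python
def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    """Copied from the listing: strided grid that always ends at T-1."""
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts


def even_timesteps(T: int, num_steps: int) -> list[int]:
    """Hypothetical alternative: exactly num_steps indices from 0 to T-1."""
    if not 2 <= num_steps <= T:
        raise ValueError("need 2 <= num_steps <= T")
    # Spread the indices evenly and round to the nearest integer timestep.
    return [round(i * (T - 1) / (num_steps - 1)) for i in range(num_steps)]


ts = make_ddim_timesteps(T=1000, num_steps=50)
print(len(ts), ts[:3], ts[-3:])      # 51 [0, 20, 40] [960, 980, 999]

even = even_timesteps(T=1000, num_steps=50)
print(len(even), even[0], even[-1])  # 50 0 999
```

With `T=1000` and `num_steps=50` the strided grid costs 51 network evaluations, because 999 is appended after 980. The difference is harmless for sample quality, but it matters if you are timing samplers against each other.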

What's Next

With this scaffolding, natural next steps are:

  • Better schedules. The cosine schedule of Nichol and Dhariwal (2021) keeps more signal at high t and trains faster.
  • Classifier-free guidance. The one trick that made text-to-image diffusion actually work. A tiny change to training (drop the conditioning with 10% probability) and sampling (combine conditional and unconditional predictions).
  • Higher-order solvers. Heun, PLMS, and DPM-Solver push the step count from 50 down to 10 to 20 with little quality loss, using the same trained network.
  • DDIM inversion and interpolation. Because DDIM at η = 0 is, up to discretization error, an invertible map, you can encode a real image to a latent and slerp between latents to get semantic interpolations. See Part 2 §8.
  • Scaling up. CIFAR-10 or CelebA 64 × 64: increase UNet channels, add self-attention at 16 × 16 resolution, train longer. Code changes are modest; training budget is not.
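As a starting point for the first item, here is a minimal sketch of the cosine schedule next to a linear one. It is written in plain Python so it runs without PyTorch; wrapping the result in `torch.tensor(...)` would be the natural way to feed it into a schedule builder. The function names are hypothetical, the offset `s = 0.008` and the clip at `0.999` follow Nichol and Dhariwal (2021), and the linear endpoints `1e-4` and `0.02` are the Ho et al. defaults, which are assumed here rather than taken from this post's `make_schedule`.

```python
import math


def cosine_betas(T: int = 1000, s: float = 0.008,
                 max_beta: float = 0.999) -> list[float]:
    """beta_1..beta_T derived from a cosine-shaped alpha-bar curve."""
    def f(t: int) -> float:
        return math.cos(((t / T + s) / (1.0 + s)) * math.pi / 2.0) ** 2

    betas = []
    for t in range(1, T + 1):
        # alpha_bar_t / alpha_bar_{t-1} = f(t) / f(t-1), so beta_t = 1 - ratio.
        beta = 1.0 - f(t) / f(t - 1)
        betas.append(min(beta, max_beta))  # clip to avoid beta_t close to 1
    return betas


def linear_betas(T: int = 1000, beta_start: float = 1e-4,
                 beta_end: float = 0.02) -> list[float]:
    """Linear schedule with the Ho et al. default endpoints (assumed)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def alphas_bar(betas: list[float]) -> list[float]:
    """Cumulative product of (1 - beta_t)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


cos_ab = alphas_bar(cosine_betas())
lin_ab = alphas_bar(linear_betas())

# Signal fraction left halfway through the forward process (t = 500).
print(f"cosine: {cos_ab[499]:.3f}  linear: {lin_ab[499]:.3f}")
```

Halfway through the forward process the cosine schedule still retains roughly half of the signal variance, while the linear one is already below a tenth. That gap is the sense in which the cosine schedule keeps more signal at high t.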

The fundamentals do not change. Every diffusion model you will encounter, from Stable Diffusion to video diffusion to molecular generation, is a variation on this series: an ELBO that reduces to a noise-prediction MSE, a UNet (or Transformer) that minimizes it, and a reverse-process sampler with a stochastic/deterministic knob.


Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
