Diffusion Deep Dive Part 3: Coding a DDPM from Scratch
In Part 1 we derived the DDPM training objective and the DDPM sampler. In Part 2 we extended that sampler to DDIM so the same network can generate in 25 steps instead of 1000. This post turns the math into running code.
We build a DDPM that trains on MNIST in about an hour on a single GPU (or a few hours on CPU, if you are patient). The point is not state-of-the-art image quality; it is to have a minimal, readable reference where every line maps to an equation from Part 1 and every sampler change maps to Part 2.
What we will build:
- The noise schedule β_t.
- The closed-form forward process q(x_t | x_0).
- A small UNet with sinusoidal time embeddings.
- The training loop (three lines of math, about ten of code).
- The iterative DDPM sampler (Algorithm 2 from Ho et al.).
- Practical stability tricks (EMA, gradient clipping, amp).
Reference: Ho, Jain, Abbeel, "Denoising Diffusion Probabilistic Models" (NeurIPS 2020). All equation numbers in this post refer to Part 1 of this series.
Setup
We will use PyTorch and torchvision; nothing else is required.
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = "cuda" if torch.cuda.is_available() else "cpu"

For the dataset, MNIST is the right starting point: it is small, the images are 28×28, and you can see whether the model is working within a few epochs. We resize to 32×32 (a power of two, which makes the UNet's down- and up-sampling clean) and scale pixels to [-1, 1] so the data looks roughly standard-normal, matching the noise we will add.
def get_loader(batch_size=128):
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),                # [0, 1]
        transforms.Normalize((0.5,), (0.5,)), # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)

The Noise Schedule
Recall the definitions from Part 1: α_t = 1 − β_t and ᾱ_t = ∏_{s=1}^{t} α_s. The noise schedule {β_t} controls how aggressively we destroy the signal at each step. Ho et al. use a linear schedule from β_1 = 10⁻⁴ to β_T = 0.02 with T = 1000. It is simple and works well enough for images.
We precompute every quantity we will ever need, for every timestep, once at init. This avoids repeated cumulative products inside the training loop.
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor                      # (T,)
    alphas: torch.Tensor                     # (T,)
    alphas_bar: torch.Tensor                 # (T,)
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor         # tilde-beta_t

def make_schedule(T=1000, beta_start=1e-4, beta_end=2e-2, device="cpu"):
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )

Two notes worth calling out:
- alphas_bar_prev is ᾱ_{t−1}. We set ᾱ_0 = 1 by convention (nothing has been corrupted at t = 0), which is the pad(..., value=1.0) trick.
- posterior_variance is β̃_t = β_t (1 − ᾱ_{t−1}) / (1 − ᾱ_t) from Part 1. This is what we use as σ_t² at sampling time.
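Both conventions are easy to verify numerically. A minimal standalone sanity check, recomputing the same quantities for a tiny T:

```python
import torch
import torch.nn.functional as F

# Tiny schedule, same construction as make_schedule above.
T = 5
betas = torch.linspace(1e-4, 2e-2, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)  # prepend alpha_bar_0 = 1
posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)

# The pad shifts alphas_bar right by one and fills position 0 with 1.0.
assert alphas_bar_prev[0].item() == 1.0
assert torch.allclose(alphas_bar_prev[1:], alphas_bar[:-1])

# At t = 0, (1 - alpha_bar_prev) = 0, so the posterior variance vanishes there.
assert posterior_variance[0].item() == 0.0

# Signal is monotonically destroyed: alpha_bar strictly decreases.
assert (alphas_bar[1:] < alphas_bar[:-1]).all()
```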
The Forward Process (Closed Form)
From Part 1, equation (5), we can sample from q(x_t | x_0) in one shot:

x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε,  ε ~ N(0, I)

No loop over t. This is the single most important efficiency property of diffusion training.
def gather(buf, t):
    """buf: (T,), t: (B,) -> (B, 1, 1, 1) for broadcasting over (B, C, H, W)."""
    return buf.gather(0, t).view(-1, 1, 1, 1)

def q_sample(x0, t, noise, sched):
    """Sample x_t from q(x_t | x_0) in closed form."""
    sqrt_ab = gather(sched.sqrt_alphas_bar, t)
    sqrt_one_minus_ab = gather(sched.sqrt_one_minus_alphas_bar, t)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

gather looks trivial but it is load-bearing: each sample in the batch has its own t, so we need per-sample coefficients broadcast over the spatial dimensions.
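To see what gather is doing, here is a tiny standalone demo (the buffer values are made up for illustration, not a real schedule):

```python
import torch

def gather(buf, t):
    """buf: (T,), t: (B,) -> (B, 1, 1, 1) for broadcasting over (B, C, H, W)."""
    return buf.gather(0, t).view(-1, 1, 1, 1)

# Toy buffer of per-timestep coefficients, T = 4.
buf = torch.tensor([0.1, 0.2, 0.3, 0.4])
t = torch.tensor([3, 0, 2])  # each sample in the batch has its own timestep

out = gather(buf, t)
assert out.shape == (3, 1, 1, 1)
# Each sample got its own coefficient, in batch order.
assert torch.allclose(out.flatten(), torch.tensor([0.4, 0.1, 0.3]))

# Broadcasting against a (B, C, H, W) batch scales each sample by its own scalar.
x = torch.ones(3, 1, 8, 8)
scaled = out * x
assert torch.allclose(scaled[1], torch.full((1, 8, 8), 0.1))
```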
The Network: A Small UNet with Time Embeddings
The network takes a noisy image and a timestep and outputs a prediction of the noise, with the same shape as the input. A UNet is the canonical choice: encoder-decoder with skip connections, which preserves spatial detail that would otherwise be lost during downsampling.
Sinusoidal Time Embeddings
The timestep t is an integer in {0, …, T−1}, but the network needs a continuous, information-rich representation. We use the same sinusoidal embedding as the Transformer (and for the same reason: it encodes position on many frequency scales at once). This embedding is then projected through an MLP and injected additively into every residual block.
class TimeEmbedding(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)  # (B, dim)
        return self.mlp(emb)  # (B, hidden)

Residual Block with Time Conditioning
Each residual block does: group-norm, SiLU, conv; inject time embedding (one scalar per channel, added as bias); group-norm, SiLU, conv; plus a skip connection. This is a simplified version of the block in Ho et al.
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

The UNet
Three resolutions: 32 → 16 → 8. Each level has one residual block going down, one at the bottleneck, and one going up, with skip connections concatenating the encoder features into the decoder. Keep it small: this is enough for MNIST and will train in minutes on a modern GPU.
class UNet(nn.Module):
    def __init__(self, in_ch=1, base=64, t_dim=128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
        # Encoder
        self.down1 = ResBlock(base, base, t_out)
        self.pool1 = nn.Conv2d(base, base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
        # Bottleneck
        self.mid = ResBlock(base * 4, base * 4, t_out)
        # Decoder (skip-connect by concatenation, so input channels double)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2 = ResBlock(base * 4, base * 2, t_out)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1)
        self.dec1 = ResBlock(base * 2, base, t_out)
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time(t)
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)                    # (B, base, 32, 32)
        h2 = self.down2(self.pool1(h1), t_emb)        # (B, 2*base, 16, 16)
        hb = self.mid(self.pool2(h2), t_emb)          # (B, 4*base, 8, 8)
        u2 = self.up2(hb)                             # (B, 2*base, 16, 16)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)                             # (B, base, 32, 32)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
        return self.out_conv(F.silu(self.out_norm(u1)))

A real implementation (Ho et al., OpenAI guided-diffusion, Stable Diffusion) adds self-attention at low resolutions, multiple residual blocks per level, and more channels. For the purposes of this post, the architecture above is intentionally stripped down.
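For reference, here is a minimal single-head self-attention block you could drop in after the bottleneck ResBlock. The class name and exact placement are my own; real implementations use multi-head attention at several resolutions, so treat this as a sketch of the idea:

```python
import torch
import torch.nn as nn

class SelfAttention2d(nn.Module):
    """Single-head self-attention over spatial positions (sketch; hypothetical
    name, not taken from any of the codebases mentioned above)."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        q = q.reshape(B, C, H * W).transpose(1, 2)       # (B, HW, C)
        k = k.reshape(B, C, H * W)                       # (B, C, HW)
        v = v.reshape(B, C, H * W).transpose(1, 2)       # (B, HW, C)
        attn = torch.softmax(q @ k / C ** 0.5, dim=-1)   # (B, HW, HW)
        out = (attn @ v).transpose(1, 2).reshape(B, C, H, W)
        return x + self.proj(out)                        # residual connection

# Shape check at the 8x8 bottleneck resolution.
attn_block = SelfAttention2d(64)
x = torch.randn(2, 64, 8, 8)
assert attn_block(x).shape == x.shape
```

Attention is quadratic in the number of spatial positions, which is why it is only applied at low resolutions (8×8 here is 64 tokens; 32×32 would be 1024).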
Training
The training step is the one you came here for. Sample t and ε, form x_t in closed form, predict the noise, compute the MSE. Four lines.
def train_step(model, x0, sched, optimizer):
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

Gradient clipping at norm 1.0 is a cheap safety net: occasional outlier batches (very small ᾱ_t, where x_t is almost pure noise and the loss surface is flat) can produce huge gradients. Without clipping, a single bad step can destabilize training.
EMA of Weights
One practical detail that matters a lot for sample quality: keep an exponential moving average (EMA) of the model weights, and sample from the EMA copy instead of the live model. Diffusion losses are noisy (different t and ε every step), so the live weights oscillate; the EMA smooths this out. Typical decay is 0.9999.
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)

    def copy_to(self, model):
        model.load_state_dict(self.shadow)

The Full Training Loop
def train(epochs=30, batch_size=128, lr=2e-4, T=1000):
    loader = get_loader(batch_size)
    sched = make_schedule(T=T, device=device)
    model = UNet(in_ch=1, base=64).to(device)
    ema = EMA(model, decay=0.9999)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
    return model, ema, sched

On an RTX 4090, 30 epochs of MNIST takes roughly 15 to 20 minutes and produces legible digits. You should see the training loss drop quickly from about 1.0 (a fresh network does no better than predicting zero noise) to a small fraction of that within the first epoch, then crawl down slowly. Do not expect it to go to zero: the loss is an expectation over all noise levels, and the low-t regime is irreducibly hard (when ᾱ_t ≈ 1 the noise contribution to x_t is nearly invisible, so the network cannot recover ε exactly).
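Which noise levels are hard is easy to probe without any training. The standalone toy experiment below (synthetic unit-variance tensors standing in for normalized images) evaluates the trivial baseline ε̂ = x_t, i.e. "treat the input as pure noise", at both ends of the schedule:

```python
import torch

torch.manual_seed(0)
# Synthetic stand-in for normalized images: unit-variance tensors.
x0 = torch.randn(512, 1, 8, 8)
noise = torch.randn_like(x0)

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def copy_baseline_mse(t):
    """MSE of the trivial predictor eps_hat = x_t."""
    ab = alphas_bar[t]
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    return ((x_t - noise) ** 2).mean().item()

low_t = copy_baseline_mse(10)    # x_t is almost all signal
high_t = copy_baseline_mse(990)  # x_t is almost all noise

assert high_t < 0.05  # copying the input is nearly perfect at high t
assert low_t > 1.0    # and hopeless at low t, where eps is masked by x_0
```

At high t the input essentially is the answer, so even this do-nothing baseline scores well; at low t the noise is buried under the signal, which is where the residual training loss lives.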
Sampling
Sampling is where the iterative structure comes back. We cannot skip timesteps the way we can in training: to sample x_{t−1} we need ε_θ(x_t, t). Algorithm 2 from DDPM:

Starting from x_T ~ N(0, I), for t = T, …, 1:

x_{t−1} = (1 / √α_t) · (x_t − (β_t / √(1 − ᾱ_t)) · ε_θ(x_t, t)) + σ_t · z

where z ~ N(0, I) for t > 1 and z = 0 at the last step. We use σ_t² = β̃_t (the posterior variance we precomputed).

The mean is exactly the μ_θ(x_t, t) expression from Part 1, Step 4 of the noise-prediction parameterization, plus a Gaussian noise injection.
@torch.no_grad()
def sample(model, sched, n=16, img_size=32, channels=1, device=device):
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean  # no noise at the last step
    model.train()
    return x.clamp(-1, 1)

To sample from the EMA weights:
@torch.no_grad()
def sample_from_ema(model, ema, sched, **kwargs):
    # swap in EMA weights, sample, swap back
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = sample(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

Why Sampling Is Slow (and How to Make It Faster)
This loop runs the network T = 1000 times per batch of samples. That is the fundamental cost of DDPM sampling and is the largest practical drawback of vanilla diffusion. Two families of improvements:
- DDIM (Song et al., 2021). A deterministic sampler that uses the same trained model but can skip timesteps, often taking 25 to 50 network evaluations instead of 1000 with minimal quality loss. It got its own post: Part 2.
- Latent diffusion (Rombach et al., 2022). Do the diffusion process in a lower-dimensional latent space learned by a VAE, rather than on raw pixels. Each step is cheaper and you need fewer of them.
For MNIST at 32×32, the naive sampler is fine.
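To preview what the DDIM option looks like in code, here is a minimal deterministic sampler (η = 0) in the style of Part 2. It is a sketch, shown with a dummy zero-noise stand-in model so it runs standalone; in practice you would pass the trained UNet and sched.alphas_bar:

```python
import torch

@torch.no_grad()
def ddim_sample(model, alphas_bar, n_steps=25, shape=(16, 1, 32, 32), device="cpu"):
    """Deterministic DDIM (eta = 0) over a strided subset of the T timesteps.
    Sketch of the update derived in Part 2, not a drop-in for sample() above."""
    T = alphas_bar.shape[0]
    times = torch.linspace(T - 1, 0, n_steps + 1).long().tolist()  # e.g. 999 -> 0
    x = torch.randn(shape, device=device)
    for t, t_prev in zip(times[:-1], times[1:]):
        ab_t, ab_prev = alphas_bar[t], alphas_bar[t_prev]
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        # Estimate x_0, then jump it straight back up to noise level t_prev.
        x0_pred = (x - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        x = ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps
    return x

# Dummy stand-in model (predicts zero noise) so the sketch runs without training.
dummy = lambda x, t: torch.zeros_like(x)
alphas_bar = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0)
out = ddim_sample(dummy, alphas_bar, n_steps=25)
assert out.shape == (16, 1, 32, 32)
```

The loop body runs 25 times instead of 1000; swapping samplers changes nothing about training.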
Putting It All Together
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
    imgs = sample_from_ema(model, ema, sched, n=64)
    # imgs is in [-1, 1]; rescale to [0, 1] for display
    imgs = (imgs + 1) / 2
    from torchvision.utils import save_image
    save_image(imgs, "samples.png", nrow=8)

After 30 epochs you should get something like clean, legible MNIST digits. If the samples look like noise or mostly uniform gray, the usual suspects are:
- You forgot the [-1, 1] normalization; the model trains against standard-normal noise, so pixels left in [0, 1] will produce visually dim samples.
- You are sampling from the live model instead of the EMA copy (quality will be visibly worse, especially early in training).
- You are gathering by t but forgot the .view(-1, 1, 1, 1), so the schedule coefficients broadcast wrong.
- You swapped a sign somewhere: the mean formula is x - coef * eps (we are removing predicted noise), not x + coef * eps.
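The normalization suspect can be ruled out mechanically: with [-1, 1] inputs, x_t stays zero-mean and roughly unit-scale at every t. A standalone check with synthetic uniform pixels standing in for normalized MNIST:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

# Synthetic pixels uniform in [-1, 1] (variance 1/3), like normalized MNIST.
x0 = torch.rand(256, 1, 32, 32) * 2 - 1
noise = torch.randn_like(x0)

for t in [0, 500, 999]:
    ab = alphas_bar[t]
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    # With properly normalized data, x_t is zero-mean at every t, and its
    # variance interpolates between Var(x0) ~ 1/3 and Var(noise) = 1.
    # Unnormalized [0, 1] pixels would shift the mean toward 0.5 at small t --
    # statistics the sampler, which starts from N(0, I), never produces.
    assert abs(x_t.mean().item()) < 0.05
    assert 0.2 < x_t.var().item() < 1.2
```

Running the same check on your actual training batches (after the transform) catches the bug before you spend an hour training.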
How Each Piece Maps to Part 1
To make the correspondence explicit:
| Math in Part 1 | Code in this post |
|---|---|
| β_t, α_t = 1 − β_t, ᾱ_t = ∏_s α_s | Schedule / make_schedule |
| β̃_t = β_t (1 − ᾱ_{t−1}) / (1 − ᾱ_t) | posterior_variance |
| x_t = √ᾱ_t x_0 + √(1 − ᾱ_t) ε | q_sample |
| ε_θ(x_t, t) | UNet.forward |
| E‖ε − ε_θ(x_t, t)‖² | F.mse_loss(noise_pred, noise) in train_step |
| μ_θ = (1/√α_t)(x_t − (β_t/√(1 − ᾱ_t)) ε_θ) | mean = sqrt_recip_alphas[t] * (x - coef * eps) in sample |
| x_{t−1} = μ_θ + σ_t z | x = mean + sqrt(posterior_variance[t]) * noise |
If any of these rows confuses you, go back to the corresponding section of Part 1; the translation is line-for-line.
What's Next
With this scaffolding, natural next steps are:
- Better schedules. The cosine schedule of Nichol and Dhariwal (2021) keeps more signal at high t and trains faster.
- Classifier-free guidance. The one trick that made text-to-image diffusion actually work. A tiny change to training (drop the conditioning with 10% probability) and sampling (combine conditional and unconditional predictions).
- Faster samplers. DDIM, DPM-Solver, and Heun-style second-order integrators get you from 1000 steps down to 10 to 50 without retraining.
- Scaling up. CIFAR-10 or CelebA: increase UNet channels, add self-attention at the 16×16 resolution, train longer. The code changes are modest; the training budget is not.
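The cosine schedule is a one-function change. A sketch following the construction in Nichol and Dhariwal (the ᾱ-based parameterization and the 0.999 clip come from the paper; treat this as a sketch rather than a verified reimplementation):

```python
import math
import torch

def make_cosine_betas(T=1000, s=0.008, max_beta=0.999):
    """Cosine schedule: specify alpha_bar(t) directly as a squared cosine,
    then back out the per-step betas, clipping each at max_beta."""
    f = lambda u: math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2
    betas = [min(1.0 - f(t + 1) / f(t), max_beta) for t in range(T)]
    return torch.tensor(betas)

cos_bar = torch.cumprod(1.0 - make_cosine_betas(), dim=0)
lin_bar = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0)

# The cosine schedule destroys signal more gently through the middle of the
# trajectory: alpha_bar stays well above the linear schedule's at t = 500.
assert cos_bar[500] > lin_bar[500]
assert cos_bar[-1] < 0.01  # but the signal is still fully destroyed by t = T
```

Plugging these betas into make_schedule in place of torch.linspace gives you the improved schedule with no other changes.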
The fundamentals do not change. Every diffusion model you will encounter, from Stable Diffusion to video diffusion to molecular generation, is a variation on the two posts in this series: an ELBO that reduces to a noise-prediction MSE, and a UNet (or Transformer) trained to minimize it.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.