Diffusion Deep Dive Part 2: DDIM — From 1000 Steps to 25 Without Retraining

Suchinthaka W. · 11 min read

Tags: diffusion, generative-models, deep-learning, ddim, ddpm, ode-samplers, machine-learning

Part 1 ended on a slightly unsatisfying note: we trained a network with a clean two-line loss, but sampling still required running it $T \approx 1000$ times in sequence. For a single image on a single GPU that is seconds; for a batch of conditional text-to-image samples with a large UNet it is minutes.

DDIM (Song, Meng, Ermon, ICLR 2021) fixes this. The headline result:

  • Same trained network. DDIM does not change the training objective in (48) of Part 1. You take any network already trained as a DDPM.
  • A family of samplers parameterized by $\eta$. $\eta = 1$ recovers DDPM. $\eta = 0$ is deterministic. Everything in between is a continuous interpolation.
  • Sub-sampling timesteps. Because the construction is non-Markov, you can run the reverse process on any subsequence $\tau_1 < \tau_2 < \cdots < \tau_S$ of $\{1, \ldots, T\}$. With $S = 25$ or $50$ you get samples indistinguishable from the $T = 1000$ DDPM ones.

This post derives all of that. If you have Part 1 in your head, there are no new tricks; it is one clean Bayes-rule construction plus one choice of variance.

Reference: Song, Meng, Ermon, "Denoising Diffusion Implicit Models" (ICLR 2021).

Notation Recap

Everything is the same as Part 1:

| Symbol | Meaning |
|---|---|
| $\mathbf{x}_0$ | Clean data sample |
| $\alpha_t = 1 - \beta_t$, $\;\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$ | Noise schedule (shared with DDPM) |
| $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ | The network trained with the DDPM loss (48) |
| $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\,\mathbf{x}_0, (1-\bar\alpha_t)\mathbf{I})$ | Forward marginal, from (24) of Part 1 |

DDIM introduces two new pieces:

| Symbol | Meaning |
|---|---|
| $\sigma_t$ | Free parameter: noise level injected at the reverse step |
| $\eta \in [0, 1]$ | Interpolation knob: $\eta{=}1$ DDPM (stochastic), $\eta{=}0$ deterministic |
| $\tau = (\tau_1, \ldots, \tau_S)$ | Sub-sampled timestep schedule of length $S \ll T$ |

All equations are numbered (1)-(N) in order of appearance.

1. The Key Insight: Training Uses Marginals, Not the Chain

Re-read the DDPM training loss from Part 1, equation (48):

$$\mathcal{L}_{\mathrm{simple}}(\theta) = \mathbb{E}_{t,\,\mathbf{x}_0,\,\boldsymbol{\epsilon}}\!\left[ \big\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\!\big(\sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon},\; t\big) \big\|^2 \right]. \tag{1}$$

Notice what does not appear: the forward chain $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$. The loss only references

$$\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}, \tag{2}$$

which is just the marginal $q(\mathbf{x}_t \mid \mathbf{x}_0)$. The full chain of transitions is a construction we used to derive the ELBO, but once the network is trained, only (2) matters.

Intuition. DDPM's forward chain is one particular way to produce a sample with the right marginal $q(\mathbf{x}_t \mid \mathbf{x}_0)$. Nothing forces us to invert that exact chain. Any process whose reverse transitions also give samples with marginal $q(\mathbf{x}_t \mid \mathbf{x}_0)$ is fair game, and can reuse the same $\boldsymbol{\epsilon}_\theta$. DDIM exploits this freedom.
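Concretely, (2) is a one-liner. A minimal NumPy sketch (the linear beta schedule is my assumption, the standard DDPM choice, not something fixed by this post):

```python
import numpy as np

# Standard linear DDPM schedule (assumed values; any schedule works).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # \bar\alpha_t for t = 0..T-1

def q_sample(x0, t, eps):
    """Draw x_t from the marginal q(x_t | x_0), eq. (2) -- no chain needed."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal(8)   # stand-in for a clean sample
eps = rng.standard_normal(8)
xt = q_sample(x0, 500, eps)
```

This is exactly what the training loop calls; no transition $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ ever appears.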

2. A Non-Markov Forward Process with the Same Marginals

Song et al. define a family of joint distributions $q_\sigma(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$, parameterized by a sequence $\sigma = (\sigma_1, \ldots, \sigma_T)$ of non-negative noise levels, such that:

  1. The marginals still agree with DDPM: $q_\sigma(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar\alpha_t}\,\mathbf{x}_0, (1-\bar\alpha_t)\mathbf{I})$.
  2. The $\mathbf{x}_0$-conditioned reverse $q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ has a specific Gaussian form we will construct.

The construction is neat: fix the marginals first, then pick $q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, and let the rest of the joint follow.

The defining posterior. For $t > 1$,

$$q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\!\Bigg(\mathbf{x}_{t-1};\; \sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0 \;+\; \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\cdot\,\frac{\mathbf{x}_t - \sqrt{\bar\alpha_t}\,\mathbf{x}_0}{\sqrt{1-\bar\alpha_t}},\; \sigma_t^2 \mathbf{I} \Bigg). \tag{3}$$

The constraint $\sigma_t^2 \leq 1 - \bar\alpha_{t-1}$ must hold for the square root to be real.

Why this mean? The term $(\mathbf{x}_t - \sqrt{\bar\alpha_t}\,\mathbf{x}_0)/\sqrt{1-\bar\alpha_t}$ is just the standardized noise that took $\mathbf{x}_0$ to $\mathbf{x}_t$; call it $\boldsymbol{\epsilon}_t$. Then (3) rewrites as

$$\mathbf{x}_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0 \;+\; \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\boldsymbol{\epsilon}_t \;+\; \sigma_t \mathbf{z}, \tag{4}$$

with $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. Check the total variance: $(1-\bar\alpha_{t-1}-\sigma_t^2) + \sigma_t^2 = 1-\bar\alpha_{t-1}$, which exactly recovers the marginal variance we need for $q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0, (1-\bar\alpha_{t-1})\mathbf{I})$. The construction is designed to preserve this by inspection.

Intuition. The mean in (3) is "land on $\sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0$, then push partway along the exact direction of the noise we observed". That direction is unit-length (after normalization), so we can split the remaining variance freely between deterministic drift along the noise and fresh Gaussian noise. $\sigma_t$ controls the split.
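The variance bookkeeping in (4) is easy to confirm numerically. A quick Monte Carlo sketch (the schedule, the scalar stand-in for an image, and the particular $\sigma_t$ are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)       # assumed linear schedule
alpha_bar = np.cumprod(1.0 - betas)

t = 500
ab_prev = alpha_bar[t - 1]
sigma_t = 0.1                            # any value with sigma_t**2 <= 1 - ab_prev

# Sample x_{t-1} via eq. (4): x0 term + drift along the noise + fresh noise.
N = 200_000
x0 = 1.7                                 # arbitrary scalar "image"
eps_t = rng.standard_normal(N)           # the standardized noise behind x_t
z = rng.standard_normal(N)               # fresh noise at the reverse step
x_prev = (np.sqrt(ab_prev) * x0
          + np.sqrt(1.0 - ab_prev - sigma_t**2) * eps_t
          + sigma_t * z)

# The marginal q(x_{t-1} | x_0) is recovered regardless of how sigma_t
# splits the variance between the two noise terms.
mean_err = abs(x_prev.mean() - np.sqrt(ab_prev) * x0)
var_err = abs(x_prev.var() - (1.0 - ab_prev))
```

Rerunning with any other admissible $\sigma_t$ leaves both errors at Monte Carlo noise level, which is the whole point of the construction.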

3. The Reverse Step in Terms of $\boldsymbol{\epsilon}_\theta$

At inference we do not know $\mathbf{x}_0$. The DDPM trick (Part 1, equation (42)) lets us predict it from $\mathbf{x}_t$ and the network:

$$\hat{\mathbf{x}}_0(\mathbf{x}_t, t) = \frac{\mathbf{x}_t - \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{\bar\alpha_t}}. \tag{5}$$

Substitute (5) for $\mathbf{x}_0$ in (4). The standardized noise $\boldsymbol{\epsilon}_t = (\mathbf{x}_t - \sqrt{\bar\alpha_t}\,\hat{\mathbf{x}}_0)/\sqrt{1-\bar\alpha_t}$ becomes exactly $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ (you can verify this by plugging in (5)). We land on the DDIM reverse update:

$$\boxed{\; \mathbf{x}_{t-1} \;=\; \sqrt{\bar\alpha_{t-1}}\,\hat{\mathbf{x}}_0(\mathbf{x}_t, t) \;+\; \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \;+\; \sigma_t\,\mathbf{z}. \;} \tag{6}$$

Three pieces, each physically meaningful:

  • $\sqrt{\bar\alpha_{t-1}}\,\hat{\mathbf{x}}_0$: "predict the clean image, then renoise it to the signal level required at $t-1$".
  • $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\boldsymbol{\epsilon}_\theta$: "push partway along the network's estimated noise direction".
  • $\sigma_t\,\mathbf{z}$: "add whatever Gaussian noise is needed to top up the variance to $1-\bar\alpha_{t-1}$".

Equation (6) is the whole sampler, except we still need to pick $\sigma_t$.
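Equations (5) and (6) together are only a few lines of code. A sketch (the function name and scalar inputs are mine, not from the post):

```python
import numpy as np

def ddim_step(x_t, eps_pred, ab_t, ab_prev, sigma_t, z):
    """One reverse update, eq. (6), with eps_pred = eps_theta(x_t, t)."""
    # eq. (5): predict the clean sample from x_t and the noise estimate
    x0_hat = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    # eq. (6): renoise to level t-1, drift along eps_pred, add fresh noise
    return (np.sqrt(ab_prev) * x0_hat
            + np.sqrt(1.0 - ab_prev - sigma_t**2) * eps_pred
            + sigma_t * z)
```

A useful sanity check: if `eps_pred` happens to be the exact noise that generated `x_t` from `x0`, then `x0_hat` recovers `x0` exactly and the step reproduces eq. (4).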

4. The $\eta$ Parameter: Stochastic vs Deterministic

Song et al. parameterize $\sigma_t$ as

$$\sigma_t(\eta) \;=\; \eta \cdot \sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,\sqrt{1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}}, \qquad \eta \in [0, 1]. \tag{7}$$

Two cases matter.

4.1. $\eta = 1$: Recovers DDPM

Plugging $\eta = 1$ into (7) and simplifying using $\bar\alpha_t / \bar\alpha_{t-1} = \alpha_t$ and $1 - \alpha_t = \beta_t$:

$$\sigma_t^2(1) \;=\; \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,(1 - \alpha_t) \;=\; \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t \;=\; \tilde\beta_t. \tag{8}$$

This is exactly the DDPM posterior variance $\tilde\beta_t$ from Part 1 equation (27). And the mean in (6) at $\eta = 1$ matches the DDPM reverse mean from Part 1 equation (44) (a couple of lines of algebra). So DDIM with $\eta = 1$ is DDPM. Nothing has been lost.
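The identity (8) can be checked numerically in a couple of lines (the linear schedule is an assumed example):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

t = 700
ab_t, ab_prev = alpha_bar[t], alpha_bar[t - 1]

# eq. (7) squared, at eta = 1
sigma2_eta1 = (1 - ab_prev) / (1 - ab_t) * (1 - ab_t / ab_prev)

# DDPM posterior variance beta-tilde (Part 1, eq. (27))
beta_tilde = (1 - ab_prev) / (1 - ab_t) * betas[t]
```

The two agree to floating-point precision for every $t$, since $\bar\alpha_t/\bar\alpha_{t-1} = 1 - \beta_t$ holds exactly by construction.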

4.2. $\eta = 0$: Deterministic Sampling

Plugging $\eta = 0$ into (7) gives $\sigma_t = 0$, and (6) collapses to

$$\mathbf{x}_{t-1} \;=\; \sqrt{\bar\alpha_{t-1}}\,\hat{\mathbf{x}}_0(\mathbf{x}_t, t) \;+\; \sqrt{1-\bar\alpha_{t-1}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t). \tag{9}$$

No noise injection. Given a fixed $\mathbf{x}_T$, the trajectory $\mathbf{x}_T \to \mathbf{x}_{T-1} \to \cdots \to \mathbf{x}_0$ is a deterministic function of $\mathbf{x}_T$ and $\theta$. This is what the community calls "DDIM sampling" in the narrow sense.

Intuition. At $\eta = 1$ each reverse step is a noisy nudge toward cleaner data. At $\eta = 0$ each step is a crisp projection: "if the current noise direction is $\boldsymbol{\epsilon}_\theta$, walk deterministically along it to the next signal level". The $\eta \in (0, 1)$ regime is a smooth family connecting the two.

5. Sub-Sampling Timesteps: The Actual Speed-Up

Here is the payoff. The construction of $q_\sigma$ in Section 2 is non-Markov in $\mathbf{x}_{1:T}$: the transition $q_\sigma(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ was defined to preserve the marginals, without reference to the chain order. Consequently, the update (6) does not require consecutive timesteps.

Pick any strictly increasing subsequence

$$\tau = (\tau_0, \tau_1, \ldots, \tau_S), \qquad \tau_0 = 0,\; \tau_S = T,\; S \ll T. \tag{10}$$

Then run the reverse update on this sparse grid:

$$\mathbf{x}_{\tau_{i-1}} \;=\; \sqrt{\bar\alpha_{\tau_{i-1}}}\,\hat{\mathbf{x}}_0(\mathbf{x}_{\tau_i}, \tau_i) \;+\; \sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_{\tau_i}, \tau_i) \;+\; \sigma_{\tau_i}\,\mathbf{z}. \tag{11}$$

We call the network $S$ times instead of $T$ times. Typical choices:

  • $\tau_i = \lfloor i \cdot T / S \rfloor$ (linear striding), or
  • $\tau_i = \lfloor i^2 \cdot T / S^2 \rfloor$ (quadratic striding, whose gaps widen toward $T$, so it spends more of the step budget at small $t$, near the data).
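Both stride rules are one-liners (a sketch; note that floor rounding can make the quadratic grid repeat small indices, which a real implementation would deduplicate):

```python
def linear_stride(T, S):
    """tau_i = floor(i * T / S) for i = 0..S: evenly spaced timesteps."""
    return [i * T // S for i in range(S + 1)]

def quadratic_stride(T, S):
    """tau_i = floor(i^2 * T / S^2): gaps widen as i grows toward S."""
    return [i * i * T // (S * S) for i in range(S + 1)]
```

With `T = 1000, S = 50`, the quadratic grid puts its first 26 of 51 points at or below $t = 250$, while the linear grid spaces them evenly.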

With $S = 50$ and $\eta = 0$ on CIFAR-10, Song et al. report FID competitive with $T = 1000$ DDPM. On CelebA and LSUN the story is similar.

Intuition. DDPM's reverse chain is Markov, which tied us to consecutive timesteps, which meant $T$ network calls. DDIM's reverse chain is non-Markov in the latents (it always conditions on the predicted $\hat{\mathbf{x}}_0$), which decouples the chain from the grid and lets us choose a much coarser grid without re-deriving anything.

6. Connection to the Probability-Flow ODE

When $\eta = 0$, equation (9) is a discretization of a deterministic differential equation. Define the continuous time $t \in [0, 1]$ via $\bar\alpha_t =$ some smooth schedule, and let $\mathrm{d}t \to 0$. The DDIM update can be rewritten as an Euler step for

$$\mathrm{d}\mathbf{x} \;=\; -\tfrac{1}{2}\,g(t)^2\left[\mathbf{x} \;+\; \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right]\mathrm{d}t, \tag{12}$$

where $g(t)^2$ is the schedule's diffusion coefficient and $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$ is the score of the noisy marginal at time $t$. This ODE is the probability-flow ODE of Song et al. (2021, "Score-Based Generative Modeling Through Stochastic Differential Equations"). The score and the noise prediction are related by

$$\nabla_{\mathbf{x}} \log p_t(\mathbf{x}) \;=\; -\frac{\boldsymbol{\epsilon}_\theta(\mathbf{x}, t)}{\sqrt{1-\bar\alpha_t}}, \tag{13}$$

so $\boldsymbol{\epsilon}_\theta$ is, up to a scalar, a score network. This reframing has two consequences:

  • Any ODE solver applies. Deterministic DDIM is Euler; higher-order solvers (Heun, DPM-Solver, PLMS) give the same or better quality with even fewer steps (often 10 to 20).
  • The latent is informative. Because the map $\mathbf{x}_T \to \mathbf{x}_0$ is deterministic and smooth, latents are meaningfully interpolable.
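Relation (13) can be verified against a case where the score is known in closed form: for the Gaussian marginal $q(\mathbf{x}_t \mid \mathbf{x}_0)$ itself, the exact score is $-(\mathbf{x}_t - \sqrt{\bar\alpha_t}\,\mathbf{x}_0)/(1-\bar\alpha_t)$, and (13) reproduces it when the "network" outputs the true standardized noise. A sketch (the helper name is mine):

```python
import numpy as np

def score_from_eps(eps_pred, ab_t):
    """eq. (13): convert a noise prediction into a score estimate."""
    return -eps_pred / np.sqrt(1.0 - ab_t)

# Compare against the exact Gaussian score of q(x_t | x_0) at a point.
ab_t, x0, x_t = 0.4, 0.7, 1.3
true_eps = (x_t - np.sqrt(ab_t) * x0) / np.sqrt(1.0 - ab_t)
exact_score = -(x_t - np.sqrt(ab_t) * x0) / (1.0 - ab_t)
```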

7. The DDIM Sampling Algorithm

Algorithm 3: DDIM sampling (Song, Meng, Ermon 2021)

Four things are worth noting:

  1. $\boldsymbol{\epsilon}_\theta$ is called exactly $S$ times, not $T$.
  2. At $\eta = 0$ the $\sigma \cdot \mathbf{z}$ term vanishes, so you can skip sampling $\mathbf{z}$.
  3. At $\eta = 1$ and $\tau_i = i$ (no sub-sampling) this algorithm is bit-identical to DDPM's Algorithm 2 from Part 1.
  4. Clamping $\hat{\mathbf{x}}_0$ to the data range (e.g. $[-1, 1]$) at each step is useful for stability; this is a common trick with no theoretical cost.
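The loop itself is short. A minimal NumPy sketch of the sampling algorithm assembled from equations (5), (6), (7), and (11) (the function names, the dummy zero-noise "network", and the schedule are my assumptions; a real run would plug in the trained $\boldsymbol{\epsilon}_\theta$):

```python
import numpy as np

def ddim_sample(eps_theta, alpha_bar, taus, eta=0.0, shape=(4,), seed=0):
    """Run the reverse update (11) backwards over the sub-sampled grid `taus`."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)                         # x_T ~ N(0, I)
    for i in range(len(taus) - 1, 0, -1):
        t, t_prev = taus[i], taus[i - 1]
        ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
        eps = eps_theta(x, t)                              # one network call
        x0_hat = (x - np.sqrt(1 - ab_t) * eps) / np.sqrt(ab_t)   # eq. (5)
        x0_hat = np.clip(x0_hat, -1.0, 1.0)                # stability clamp (note 4)
        sigma = eta * np.sqrt((1 - ab_prev) / (1 - ab_t)) \
                    * np.sqrt(1 - ab_t / ab_prev)          # eq. (7)
        z = rng.standard_normal(shape) if eta > 0 else 0.0
        x = (np.sqrt(ab_prev) * x0_hat
             + np.sqrt(1 - ab_prev - sigma**2) * eps
             + sigma * z)                                  # eq. (6)
    return x

# Smoke run with a dummy "network" that predicts zero noise.
T, S = 1000, 50
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))
taus = [i * (T - 1) // S for i in range(S + 1)]            # grid over 0 .. T-1
sample = ddim_sample(lambda x, t: np.zeros_like(x), alpha_bar, taus, eta=0.0)
```

At `eta=0.0` the run is fully deterministic given the seed of $\mathbf{x}_T$, matching Section 4.2.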

8. What Deterministic DDIM Enables

Because $\eta = 0$ sampling is a deterministic invertible map $\mathbf{x}_T \to \mathbf{x}_0$, three non-obvious things become possible.

(a) Image encoding. Given a real image $\mathbf{x}_0^\star$, the reverse-direction Euler step of (9) encodes it back to a latent $\mathbf{x}_T^\star$ such that running DDIM forward on $\mathbf{x}_T^\star$ reconstructs (to numerical precision) the original image. This is how DDIM Inversion works.

(b) Semantic interpolation. Interpolating two real images is often disappointing: mixing pixels gives ghostly results. Instead, DDIM-invert both to latents $\mathbf{x}_T^{(a)}, \mathbf{x}_T^{(b)}$ and slerp (spherical linear interpolation) between them, then DDIM-sample. The intermediate outputs cross through plausible, in-distribution samples.
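Slerp itself is short. A sketch operating on flattened latent vectors (a hypothetical helper, not code from the paper):

```python
import numpy as np

def slerp(a, b, lam):
    """Spherical interpolation between latents a and b, lam in [0, 1]."""
    a_unit = a / np.linalg.norm(a)
    b_unit = b / np.linalg.norm(b)
    omega = np.arccos(np.clip(np.dot(a_unit, b_unit), -1.0, 1.0))
    if np.isclose(omega, 0.0):           # nearly parallel: fall back to lerp
        return (1 - lam) * a + lam * b
    return (np.sin((1 - lam) * omega) * a + np.sin(lam * omega) * b) / np.sin(omega)
```

The reason slerp beats straight lerp here: high-dimensional Gaussian latents concentrate near a sphere of radius $\sqrt{d}$, and slerp keeps intermediate points near that shell.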

(c) Deterministic seeds. Fixing $\mathbf{x}_T$ fixes the output entirely. This is the basis for A/B testing prompts under classifier-free guidance (same seed, different text conditioning, directly comparable outputs) in text-to-image systems.

The stochastic DDPM sampler cannot do any of these, because the path from $\mathbf{x}_T$ to $\mathbf{x}_0$ integrates fresh Gaussian noise at every step.

9. Summary

| Quantity | DDPM ((51) in Part 1) | DDIM $\eta{=}0$ |
|---|---|---|
| Training loss | $\mathcal{L}_{\mathrm{simple}}$ (48) | same |
| Network $\boldsymbol{\epsilon}_\theta$ | required | same |
| Reverse step | stochastic, Markov | deterministic, non-Markov in latents |
| Steps needed | $\sim 1000$ | 25 to 50 (or 10 to 20 with better ODE solvers) |
| $\mathbf{x}_T \to \mathbf{x}_0$ map | many-to-many (noise injected at each step) | invertible bijection |
| Interpolation / inversion | no | yes |

The path from DDPM to DDIM in one sentence: redefine the forward process as non-Markov while keeping the marginals, and a family of samplers parameterized by $\eta$ falls out, with $\eta = 0$ being a fast deterministic solver for the same trained network.

What's Next

In Part 3 we build a DDPM end-to-end in PyTorch: schedule, UNet with time embeddings, training loop from (48) of Part 1, and the reverse sampler (51). Swapping in the DDIM sampler from this post is a ten-line change: replace the sampling loop with the algorithm above, keep $\boldsymbol{\epsilon}_\theta$ untouched, and set $S = 50$, $\eta = 0$. Same network, twenty times faster.


Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
