Transformer Deep Dive: Part 5 - Training Improvements
Training a large language model is one of the most resource-intensive computational tasks in modern AI. A single GPT-4 scale training run can cost tens of millions of dollars and consume megawatt-hours of electricity over several months. At this scale, every percentage point of efficiency matters, and the difference between a stable and unstable training run can mean millions of dollars saved or wasted.
In this post, we examine the key techniques that make modern LLM training feasible: optimizers designed for the scale of transformer parameters, learning rate schedules that guide convergence, mixed precision arithmetic that doubles throughput without sacrificing model quality, gradient checkpointing to fit larger models in memory, gradient clipping for stability, and distributed training strategies that coordinate hundreds or thousands of GPUs.
The Training Loop
At its core, LLM training follows the standard supervised learning loop: forward pass to compute the loss, backward pass to compute gradients, and an optimizer step to update parameters. But the details at transformer scale are anything but standard.
The canonical training step updates parameters according to:

$$\theta_{t+1} = \theta_t - \eta_t \, u_t$$

where $\eta_t$ is a time-varying learning rate (governed by a schedule), $g_t = \nabla_\theta \mathcal{L}(\theta_t)$ is the gradient of the language modeling loss (typically cross-entropy over the vocabulary), and $u_t$ is the optimizer-specific transformation of the raw gradient $g_t$.
For autoregressive language models, the loss on a sequence $x_1, \ldots, x_T$ is the average negative log-likelihood of next-token prediction:

$$\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$
A modern training step involves far more than this equation suggests. The forward pass runs in reduced precision (BF16), gradients are accumulated across micro-batches, clipped to prevent explosions, then synchronized across hundreds of GPUs before the optimizer applies its update in FP32.
```python
import torch
from torch.cuda.amp import autocast

def training_step(model, batch, optimizer, grad_accum_steps, max_grad_norm):
    """A single training step with mixed precision, gradient accumulation, and clipping.

    BF16 autocast needs no GradScaler; that machinery is only required for FP16.
    """
    optimizer.zero_grad()
    total_loss = 0.0
    for micro_step in range(grad_accum_steps):
        micro_batch = batch[micro_step]
        with autocast(dtype=torch.bfloat16):
            logits = model(micro_batch["input_ids"])
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                micro_batch["labels"].view(-1),
                ignore_index=-100,
            )
        loss = loss / grad_accum_steps  # Normalize by accumulation steps
        loss.backward()
        total_loss += loss.item()
    # Gradient clipping before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    return total_loss
```
Optimizers: From SGD to AdamW
Vanilla SGD and Its Limitations
Stochastic Gradient Descent updates parameters by subtracting the gradient scaled by a learning rate: $\theta_{t+1} = \theta_t - \eta \, g_t$. This works well for convex problems but struggles with the highly non-convex loss landscapes of transformers. The gradient can oscillate wildly across different parameter groups --- embedding matrices may have gradients many orders of magnitude larger than attention weight matrices --- making a single global learning rate inadequate.
Adam: Adaptive Moment Estimation
Adam (Kingma & Ba, 2015) addresses this by maintaining per-parameter running estimates of the first moment (mean) and second moment (uncentered variance) of the gradient $g_t$:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

Because $m$ and $v$ are initialized at zero, they are biased toward zero during early training. The bias-corrected estimates are:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

The update divides the first moment by the square root of the second moment, effectively giving each parameter its own adaptive learning rate:

$$\theta_{t+1} = \theta_t - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Standard values are $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. For LLM training, $\beta_2 = 0.95$ is increasingly common (as used in LLaMA and GPT-3), which makes the variance estimate more responsive to recent gradients and can improve stability in later training stages.
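To see the bias correction at work, here is a tiny numeric check (the constant gradient of 0.5 is a hypothetical value chosen for illustration): with a fixed gradient $g$, the raw estimate $m_t = (1 - \beta_1^t)\,g$ undershoots early on, while the corrected $\hat{m}_t$ recovers $g$ exactly at every step.

```python
beta1 = 0.9
g = 0.5  # pretend the true gradient is constant

m = 0.0
for t in range(1, 11):
    m = beta1 * m + (1 - beta1) * g   # biased running mean, starts near zero
    m_hat = m / (1 - beta1 ** t)      # bias-corrected estimate
    # m starts at 0.05 and only slowly approaches 0.5;
    # m_hat equals 0.5 (up to rounding) from the very first step.
    print(t, round(m, 4), round(m_hat, 4))
```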
AdamW: Why L2 Regularization Fails in Adam
Loshchilov & Hutter (2019) identified a subtle but critical flaw in how Adam handles weight decay. To understand it, consider how L2 regularization is typically implemented: the loss becomes $\mathcal{L}(\theta) + \frac{\lambda}{2} \|\theta\|^2$, so the gradient becomes $g_t + \lambda \theta_t$.
In vanilla SGD, applying L2 regularization through the gradient is mathematically equivalent to applying weight decay directly. The SGD update with L2 is:

$$\theta_{t+1} = \theta_t - \eta \, (g_t + \lambda \theta_t)$$

This is identical to subtracting $\eta \lambda \theta_t$ directly from the parameters (weight decay). However, in Adam, the regularization gradient gets divided by the adaptive denominator $\sqrt{\hat{v}_t} + \epsilon$ (momentum omitted here for clarity):

$$\theta_{t+1} \approx \theta_t - \eta \cdot \frac{g_t + \lambda \theta_t}{\sqrt{\hat{v}_t} + \epsilon}$$
This means parameters with large historical gradients (large $\hat{v}_t$) receive less effective regularization, while parameters with small gradients receive more. This is the opposite of what we want --- large, active parameters should arguably be regularized more strongly. The adaptive scaling that makes Adam effective for optimization actively undermines the regularization.
AdamW fixes this by applying weight decay directly to the parameters, completely bypassing the adaptive scaling:

$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)$$

Note the key difference: weight decay ($\lambda \theta_t$) is added after the adaptive division, not before. Every parameter now receives the same proportional decay regardless of its gradient history.
```python
import torch

class AdamW:
    """AdamW optimizer with decoupled weight decay.

    The key difference from Adam: weight decay is applied directly to
    parameters, not through the gradient (which would be scaled by the
    adaptive learning rate).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.95),
                 eps=1e-8, weight_decay=0.1):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0
        # Initialize moment buffers
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        self.t += 1
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            grad = param.grad.data
            # Update biased first and second moment estimates
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * grad ** 2
            # Bias correction
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            # AdamW update: weight decay is DECOUPLED from the adaptive update
            param.data -= self.lr * (
                m_hat / (v_hat.sqrt() + self.eps)   # Adaptive gradient step
                + self.weight_decay * param.data    # Direct weight decay
            )
```
Optimizer Memory Overhead
Every optimizer state variable costs memory proportional to model size. For a 70B parameter model in FP32, each extra state tensor adds 280 GB:
| Optimizer | States per Parameter | Extra Memory per Param | Notes |
|---|---|---|---|
| SGD | 0 | 0 bytes | Simplest, but requires careful LR tuning |
| SGD + Momentum | 1 (momentum buffer) | 4 bytes | More stable convergence |
| Adam / AdamW | 2 ($m$, $v$) | 8 bytes | Adaptive LR, standard for LLMs |
| Adafactor | ~1 (factored $v$) | ~4 bytes | Factorizes second moment as outer product |
| Lion | 1 (momentum buffer) | 4 bytes | Sign-based update, competitive results |
| 8-bit Adam | 2 (quantized) | 2 bytes | Quantized optimizer states |
For a 70B model, AdamW requires ~560 GB of optimizer state alone (two FP32 tensors), which is a primary motivation for techniques like FSDP and 8-bit optimizers.
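The figures above are easy to recompute. A small helper (decimal GB, matching the numbers in the text), with per-optimizer byte counts taken from the table:

```python
# Extra optimizer-state bytes per parameter, per the table above.
STATE_BYTES = {
    "sgd": 0,
    "sgd_momentum": 4,
    "adamw": 8,       # two FP32 tensors: m and v
    "adafactor": 4,   # approximate: factored second moment
    "lion": 4,
    "adam_8bit": 2,
}

def optimizer_state_gb(n_params: float, optimizer: str) -> float:
    """Extra optimizer-state memory in decimal GB."""
    return n_params * STATE_BYTES[optimizer] / 1e9

print(optimizer_state_gb(70e9, "adamw"))      # 560.0 GB for a 70B model
print(optimizer_state_gb(70e9, "adam_8bit"))  # 140.0 GB with quantized states
```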
Learning Rate Schedules
The learning rate is arguably the single most important hyperparameter in LLM training. Modern practice universally uses schedules that combine warmup with a decay phase.
Linear Warmup
Warmup is essential for transformer training stability. During the first $T_{\text{warmup}}$ steps, the learning rate increases linearly from near-zero to the peak value:

$$\eta_t = \eta_{\text{peak}} \cdot \frac{t}{T_{\text{warmup}}}$$
Why is warmup necessary? Three factors converge:
- Adam's variance estimate is unreliable early on. The second moment $v_t$ is initialized to zero and takes many steps to reflect the true gradient variance. With a large learning rate, the denominator $\sqrt{\hat{v}_t} + \epsilon$ is artificially small, causing extremely large updates that can destabilize training or push the model into a bad loss basin.
- Layer normalization gradients are large initially. Before the model learns meaningful representations, gradients through layer normalization can be very noisy. Warmup gives the normalization statistics time to stabilize.
- Embedding matrices see sparse, high-magnitude gradients. Only a small subset of tokens appears in each batch, but those token embeddings receive concentrated gradient updates. Warmup prevents these few vectors from being pushed too far before the model has a chance to learn distributional patterns.
Typical warmup durations are 0.1-2% of total training steps. LLaMA uses 2000 steps; GPT-3 uses warmup over the first 375 million tokens.
Cosine Decay
After warmup, the learning rate follows a cosine curve from $\eta_{\text{peak}}$ down to $\eta_{\text{min}}$:

$$\eta_t = \eta_{\text{min}} + \frac{1}{2}\left(\eta_{\text{peak}} - \eta_{\text{min}}\right)\left(1 + \cos\left(\pi \cdot \frac{t - T_{\text{warmup}}}{T - T_{\text{warmup}}}\right)\right)$$

where $T_{\text{warmup}}$ is the warmup period and $T$ is the total training duration. Cosine decay provides a smooth, gradual reduction that spends more time at moderate learning rates than linear decay. Most LLMs set $\eta_{\text{min}} = 0.1 \, \eta_{\text{peak}}$.
Linear Decay
A simpler alternative that decreases the learning rate at a constant rate:

$$\eta_t = \eta_{\text{peak}} \left(1 - \frac{t - T_{\text{warmup}}}{T - T_{\text{warmup}}}\right)$$
Linear decay is less common for large-scale pretraining but still used in some fine-tuning scenarios.
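Written as a multiplier-on-peak function (the shape that PyTorch's LambdaLR expects), linear decay with warmup looks like the sketch below; the 0.1 floor is an assumption mirroring the common "decay to 10% of peak" convention rather than anything mandated by the schedule itself:

```python
def linear_schedule(step, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Fraction of peak LR: linear warmup, then linear decay to a floor."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(min_lr_ratio, 1.0 - (1.0 - min_lr_ratio) * progress)

# Warmup for 100 steps, decay over the remaining 900.
print(linear_schedule(50, 100, 1000))    # 0.5 (mid-warmup)
print(linear_schedule(100, 100, 1000))   # 1.0 (peak)
print(linear_schedule(1000, 100, 1000))  # 0.1 (floor)
```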
Warmup-Stable-Decay (WSD)
An emerging schedule used in recent models (e.g., MiniCPM). WSD has three phases:
- Warmup: Linear ramp-up (same as above).
- Stable: Hold at peak learning rate for the majority of training.
- Decay: Rapid cosine or exponential decay in the final phase (typically last 10-20%).
The insight behind WSD is that a constant learning rate during the middle phase enables the model to explore the loss landscape more freely, while the rapid final decay allows it to settle into a sharp minimum. This can also make it easier to resume training or extend the training budget without restarting the schedule from scratch.
```python
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps,
                                    min_lr_ratio=0.1):
    """Cosine decay with linear warmup, used by LLaMA, GPT-3, etc."""
    def lr_lambda(step):
        # Linear warmup
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Cosine decay to min_lr_ratio of peak
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
    return LambdaLR(optimizer, lr_lambda)

def get_wsd_schedule(optimizer, warmup_steps, stable_steps, decay_steps,
                     min_lr_ratio=0.0):
    """Warmup-Stable-Decay schedule as used in MiniCPM."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        elif step < warmup_steps + stable_steps:
            return 1.0
        else:
            decay_progress = (step - warmup_steps - stable_steps) / max(1, decay_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
    return LambdaLR(optimizer, lr_lambda)
```
Training Recipes from Real Models
The following table summarizes hyperparameter choices from published LLM training runs:
| Hyperparameter | GPT-3 (175B) | LLaMA 2 (70B) | Chinchilla (70B) | Mistral (7B) |
|---|---|---|---|---|
| Optimizer | Adam | AdamW | AdamW | AdamW |
| Peak LR | 0.6e-4 | 1.5e-4 | 1.0e-4 | 3.0e-4 |
| Adam $\beta_1$, $\beta_2$ | 0.9, 0.95 | 0.9, 0.95 | 0.9, 0.95 | 0.9, 0.95 |
| Weight Decay | 0.1 | 0.1 | 0.1 | 0.1 |
| Warmup | 375M tokens | 2000 steps | 1500 steps | 1000 steps |
| LR Schedule | Cosine | Cosine | Cosine | Cosine |
| Final LR | 10% of peak | 10% of peak | 10% of peak | 10% of peak |
| Batch Size (tokens) | 3.2M | 4M | 1.5M | 4M |
| Gradient Clipping | 1.0 | 1.0 | 1.0 | 1.0 |
| Total Tokens | 300B | 2T | 1.4T | ~8T (estimated) |
The convergence across different labs and model scales is striking --- the community has largely settled on AdamW with $\beta_2 = 0.95$, cosine decay to 10% of peak, and gradient clipping at 1.0.
Gradient Clipping
Gradient clipping is a simple but essential technique for training stability. Without it, a single bad batch can produce enormous gradients that corrupt learned parameters and cause loss spikes from which the model may never recover.
The standard approach is global norm clipping. First, compute the global gradient norm across all parameters:

$$\|g\| = \sqrt{\sum_i \|g_i\|^2}$$

If this exceeds a threshold $c$ (typically $c = 1.0$), scale all gradients down proportionally:

$$g \leftarrow g \cdot \frac{c}{\|g\|}$$
This preserves the direction of the gradient update while limiting its magnitude. The clipping threshold of 1.0 is nearly universal across LLM training runs, and gradient norm monitoring is one of the most important training diagnostics --- a sustained increase in gradient norms often precedes a loss spike or divergence.
```python
import torch

def clip_grad_norm_(parameters, max_norm=1.0):
    """Global gradient norm clipping (simplified version of PyTorch's implementation)."""
    parameters = [p for p in parameters if p.grad is not None]
    # Compute global norm across all parameter gradients
    total_norm = torch.sqrt(sum(p.grad.data.norm() ** 2 for p in parameters))
    # clip_coef = min(1, max_norm / total_norm): only shrink, never amplify
    clip_coef = max_norm / max(total_norm.item(), max_norm)
    for p in parameters:
        p.grad.data.mul_(clip_coef)
    return total_norm
```
Mixed Precision Training
The Numerical Precision Landscape
Modern GPUs offer several floating-point formats, each with different tradeoffs between range, precision, and throughput:
| Format | Total Bits | Sign | Exponent | Mantissa | Dynamic Range | Use Case |
|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | ~1e-38 to ~3e38 | Master weights, optimizer states |
| TF32 | 19 | 1 | 8 | 10 | ~1e-38 to ~3e38 | NVIDIA Tensor Core internal |
| BF16 | 16 | 1 | 8 | 7 | ~1e-38 to ~3e38 | Forward/backward pass |
| FP16 | 16 | 1 | 5 | 10 | ~6e-5 to 65504 | Forward/backward (with loss scaling) |
| FP8 (E4M3) | 8 | 1 | 4 | 3 | max ±448 | Emerging for training |
| FP8 (E5M2) | 8 | 1 | 5 | 2 | max ±57344 | Emerging for gradients |
BF16 vs FP16: Why BF16 Won
The choice between BF16 and FP16 comes down to range vs. precision:
FP16 allocates 5 bits to the exponent and 10 to the mantissa. It has good precision (roughly 3 decimal digits) but a maximum representable value of only 65504. During transformer training, loss values and intermediate activations can easily exceed this range. When they do, the result is either infinity or NaN, and training diverges. The workaround is loss scaling: multiply the loss by a large constant before the backward pass (to push small gradients into representable range), then divide the gradients back down before the optimizer step. A GradScaler dynamically adjusts this scale factor, reducing it whenever overflow is detected.
BF16 allocates 8 bits to the exponent (same as FP32) and only 7 to the mantissa. This gives it the same dynamic range as FP32 (up to roughly $3.4 \times 10^{38}$) at the cost of lower precision (roughly 2 decimal digits). The key insight is that for neural network training, range matters more than precision. Gradients vary over many orders of magnitude, and BF16 handles this natively without any loss scaling. The reduced precision is acceptable because the stochasticity of SGD already introduces noise far larger than BF16 rounding errors.
In practice, BF16 has become the standard for LLM training because it eliminates the fragile loss-scaling machinery, simplifies the training code, and produces equivalent final model quality.
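The range difference can be demonstrated in pure Python with bit-level conversions. This is a sketch: BF16 is emulated here by truncating the low 16 bits of the FP32 bit pattern, slightly cruder than hardware round-to-nearest, and FP16 overflow is surfaced by catching `struct`'s OverflowError:

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision; overflow becomes inf."""
    try:
        return struct.unpack(">e", struct.pack(">e", x))[0]
    except OverflowError:
        return float("inf")

def to_bf16(x: float) -> float:
    """Emulate BF16 by zeroing the low 16 bits of the FP32 bit pattern."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

loss = 70000.0  # a perfectly ordinary activation magnitude
print(to_fp16(loss))  # overflows FP16's max of 65504
print(to_bf16(loss))  # representable in BF16, just less precise
```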
The Mixed Precision Strategy
Even with BF16, certain operations must remain in FP32 to maintain numerical stability:
- Master weights are stored in FP32. The optimizer updates are computed in FP32 and applied to these master weights.
- Forward pass runs in BF16. A BF16 copy of the weights is used for the forward computation. On modern GPUs, this runs at 2x the throughput of FP32 on Tensor Cores.
- Backward pass runs in BF16. Gradients are computed in reduced precision.
- Gradient accumulation is done in FP32. When accumulating gradients across micro-batches, the accumulated buffer must be FP32 to avoid precision loss from repeated additions of small values.
- Optimizer step operates in FP32. The moment estimates ($m$, $v$) and the parameter update are computed in full precision.
Certain numerical operations should always use FP32 regardless of the global precision setting: softmax (exponentiation is sensitive to rounding), layer normalization (variance computation), and loss computation (log-probabilities can be very small).
```python
import torch
from torch.cuda.amp import autocast

def mixed_precision_training_loop(model, dataloader, optimizer, num_epochs,
                                  grad_accum_steps=4, max_grad_norm=1.0):
    """Full mixed precision training loop with BF16.

    BF16 does not require GradScaler (unlike FP16), simplifying the code.
    """
    model.train()
    for epoch in range(num_epochs):
        for step, batch in enumerate(dataloader):
            is_accumulation_step = (step + 1) % grad_accum_steps != 0
            # Forward pass in BF16
            with autocast(dtype=torch.bfloat16):
                logits = model(batch["input_ids"])
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    batch["labels"].view(-1),
                    ignore_index=-100,
                )
            loss = loss / grad_accum_steps
            # Backward pass (gradients computed in BF16, accumulated in FP32)
            loss.backward()
            if not is_accumulation_step:
                # Clip gradients (computed on FP32 gradient buffers)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                # Optimizer step updates FP32 master weights
                optimizer.step()
                optimizer.zero_grad()
                # Log only on update steps, where grad_norm is defined
                if step % 100 == 0:
                    print(f"Step {step}, Loss: {loss.item():.4f}, "
                          f"Grad Norm: {grad_norm:.4f}")
```
Memory Savings from Mixed Precision
For a model with $N$ parameters, the memory breakdown per parameter is:
| Component | FP32 Only | Mixed Precision (BF16 + FP32) |
|---|---|---|
| Model weights | 4 bytes (FP32) | 2 bytes (BF16) + 4 bytes (FP32 master) |
| Gradients | 4 bytes (FP32) | 2 bytes (BF16) |
| Optimizer states (AdamW) | 8 bytes (2x FP32) | 8 bytes (2x FP32) |
| Total per parameter | 16 bytes | 16 bytes |
| Activations (saved for backward) | FP32 | BF16 (2x reduction) |
The parameter-level savings are modest, but the activation memory savings are substantial. Activations dominate memory for long sequences, and storing them in BF16 cuts that memory in half, enabling longer context lengths or larger batch sizes.
Gradient Checkpointing
The Activation Memory Problem
During the backward pass, we need the activations from the forward pass to compute gradients. Naively, this means storing activations for every layer. For a model with $L$ layers, batch size $B$, sequence length $S$, and hidden dimension $H$, the hidden states alone require:

$$\text{Memory} \propto L \cdot B \cdot S \cdot H \cdot (\text{bytes per element})$$

For LLaMA-70B ($L = 80$, $H = 8192$) with a batch of $B = 8$ sequences of length $S = 4096$ in BF16: this is roughly 43 GB just for the hidden state activations. Including attention scores, FFN intermediates, and normalization buffers, the actual figure is several times higher and can easily exceed 100 GB.
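These back-of-envelope numbers come from a one-line estimate; the batch size and sequence length below are illustrative assumptions, not published training settings:

```python
def hidden_state_gb(n_layers, batch, seq_len, d_model, bytes_per_elem=2):
    """Hidden-state activation memory across all layers, decimal GB."""
    return n_layers * batch * seq_len * d_model * bytes_per_elem / 1e9

# LLaMA-70B-like shape: 80 layers, d_model 8192, BF16 (2 bytes),
# with an assumed micro-batch of 8 sequences of length 4096.
print(hidden_state_gb(80, 8, 4096, 8192))  # ~42.9 GB
```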
How Checkpointing Works
Gradient checkpointing (Chen et al., 2016) trades compute for memory. Instead of storing all activations, we designate certain layers as "checkpoints" and only store their outputs. During the backward pass, when we need activations from non-checkpointed layers, we recompute them by running a partial forward pass from the nearest checkpoint.
The most common strategy is to checkpoint every transformer block boundary:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    """A single transformer block (simplified)."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        h = self.attn_norm(x)  # Pre-norm; reused for query, key, and value
        x = x + self.attn(h, h, h)[0]
        x = x + self.ff(self.ff_norm(x))
        return x

class CheckpointedTransformerModel(nn.Module):
    """Transformer with gradient checkpointing to reduce activation memory."""

    def __init__(self, n_layers, d_model, n_heads, d_ff, use_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.use_checkpointing and self.training:
                # Activations inside this layer are NOT stored;
                # they are recomputed during the backward pass.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x
```
The Memory-Compute Tradeoff
The tradeoff is clean and predictable. Checkpointing every layer means we store only each block's input rather than all of its internal activations, reducing activation memory from $O(L)$ full activation sets to block inputs alone, but we recompute every layer's forward pass during the backward pass, roughly doubling the forward compute. The optimal strategy checkpoints every $\sqrt{L}$ layers, achieving $O(\sqrt{L})$ memory with approximately 50% compute overhead.
| Strategy | Activation Memory | Forward Compute Overhead | When to Use |
|---|---|---|---|
| No checkpointing | $O(L)$ | 1.0x | Small models, ample GPU memory |
| Every $\sqrt{L}$ layers | $O(\sqrt{L})$ | ~1.5x | Balanced approach |
| Every layer | Block inputs only | ~2.0x | Maximum memory savings |
| Selective (attention only) | ~0.5x | ~1.3x | Attention is the bottleneck |
In practice, checkpointing every layer is most common for large model training because the compute overhead (roughly 30-40% in practice, less than the theoretical 2x due to memory bandwidth effects) is acceptable given the memory savings.
Distributed Training
A 70B parameter model in FP32 requires 280 GB just for the weights, far exceeding any single GPU's memory. Even with mixed precision and checkpointing, training at scale requires distributing the computation across many GPUs.
Data Parallelism (DDP)
The simplest distributed strategy is to replicate the model on every GPU and split the data. Each GPU processes a different mini-batch, computes gradients independently, then synchronizes gradients via an all-reduce operation before the optimizer step.
DDP (Distributed Data Parallel) in PyTorch overlaps gradient communication with backward computation: as soon as a layer's gradients are computed, the all-reduce for that layer begins while the backward pass continues through earlier layers.
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    """Initialize the DDP process group (one process per GPU)."""
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_with_ddp(rank, world_size, model, dataloader, optimizer):
    setup_ddp(rank, world_size)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    for batch in dataloader:
        batch = {k: v.to(rank) for k, v in batch.items()}
        loss = model(batch["input_ids"], labels=batch["labels"]).loss
        loss.backward()  # Gradients all-reduced automatically
        optimizer.step()
        optimizer.zero_grad()
```
DDP's limitation: every GPU must hold a full copy of the model. For a 70B model with AdamW, this means ~1.1 TB per GPU (weights + gradients + optimizer states) --- clearly infeasible.
Fully Sharded Data Parallelism (FSDP)
FSDP (Zhao et al., 2023) extends data parallelism by sharding model parameters, gradients, and optimizer states across GPUs. This is conceptually similar to DeepSpeed ZeRO Stage 3.
The key operations in FSDP are:
- All-gather before each layer's forward pass: collect the full parameter tensor from all GPUs.
- Forward computation: Run the layer with the full parameters.
- Discard the gathered parameters after use (only keep the local shard).
- All-gather again during the backward pass, compute gradients.
- Reduce-scatter gradients: each GPU receives the gradient shard corresponding to its parameter shard.
```python
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Define wrapping policy: shard at the transformer block level
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Mixed precision configuration
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # Parameters gathered in BF16
    reduce_dtype=torch.float32,   # Gradient reduction in FP32
    buffer_dtype=torch.bfloat16,  # Buffers in BF16
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mixed_precision_policy,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)
```
Per-GPU Memory Comparison
The memory reduction from FSDP is dramatic. For a 70B parameter model with AdamW in mixed precision across 8 GPUs:
| Component | DDP (per GPU) | FSDP Full Shard (per GPU) |
|---|---|---|
| Parameters (BF16) | 140 GB | 17.5 GB |
| Gradients (BF16) | 140 GB | 17.5 GB |
| Optimizer states (FP32) | 560 GB | 70 GB |
| FP32 master weights | 280 GB | 35 GB |
| Total | 1,120 GB | 140 GB |
This makes it feasible to train a 70B model on 8x 80GB A100s with FSDP, which would be impossible with DDP.
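The table divides cleanly by the number of shards. A small sketch reproducing it, using the 16 bytes of training state per parameter broken down in the mixed-precision section:

```python
def training_state_gb(n_params, world_size=1, bytes_per_param=16):
    """Per-GPU training-state memory in decimal GB when fully sharded.

    16 bytes/param = BF16 weights (2) + BF16 grads (2)
                   + FP32 master weights (4) + AdamW m and v (8).
    """
    return n_params * bytes_per_param / world_size / 1e9

print(training_state_gb(70e9))                # 1120.0 GB: DDP, full replica
print(training_state_gb(70e9, world_size=8))  # 140.0 GB: FSDP across 8 GPUs
```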
Tensor Parallelism (TP)
Tensor parallelism (Shoeybi et al., 2019) splits individual matrix operations across GPUs. For a linear layer $Y = XW$ where $W \in \mathbb{R}^{d \times d}$, we can partition $W$ column-wise across $k$ GPUs:

$$W = [W_1, W_2, \ldots, W_k]$$

Each GPU computes $Y_i = X W_i$ and the results are concatenated via an all-gather. For the MLP in a transformer block, this is applied to both linear layers with complementary splits (column-parallel for the first, row-parallel for the second) to minimize communication.
Tensor parallelism is typically used with degree 2, 4, or 8 within a single node where GPUs are connected by NVLink (900 GB/s on H100 SXM), because the all-reduce communication at every layer is latency-sensitive.
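A toy pure-Python check that splitting a weight matrix's columns, multiplying each shard independently, and concatenating the results reproduces the full product (tiny integer matrices standing in for per-GPU shards):

```python
def matmul(A, B):
    """Naive dense matmul over lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def split_columns(W, k):
    """Partition W column-wise into k contiguous shards (one per 'GPU')."""
    n = len(W[0]) // k
    return [[row[i * n:(i + 1) * n] for row in W] for i in range(k)]

def concat_columns(parts):
    """All-gather stand-in: concatenate the per-shard outputs column-wise."""
    return [sum(rows, []) for rows in zip(*parts)]

X = [[1, 2], [3, 4]]
W = [[1, 0, 2, 1], [0, 1, 1, 2]]
shards = split_columns(W, k=2)
Y_parallel = concat_columns([matmul(X, Wi) for Wi in shards])
print(Y_parallel == matmul(X, W))  # True
```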
Pipeline Parallelism (PP)
Pipeline parallelism assigns different groups of layers to different GPUs. A model with 80 layers across 4 GPUs might assign layers 1-20 to GPU 0, 21-40 to GPU 1, and so on.
The naive implementation creates "pipeline bubbles" where GPUs sit idle waiting for activations from earlier stages. Micro-batching (GPipe) and interleaved scheduling (PipeDream) reduce bubble overhead by breaking the batch into smaller micro-batches and overlapping computation.
The bubble fraction for a $p$-stage pipeline with $m$ micro-batches is approximately:

$$\text{bubble fraction} \approx \frac{p - 1}{m}$$

With $m \gg p$, the bubble overhead becomes negligible.
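Plugging in hypothetical stage and micro-batch counts shows why the micro-batch count matters so much:

```python
def bubble_fraction(p: int, m: int) -> float:
    """Approximate idle fraction for a p-stage pipeline with m micro-batches."""
    return (p - 1) / m

print(bubble_fraction(4, 4))   # 0.75: severe bubbles, GPUs mostly idle
print(bubble_fraction(4, 32))  # 0.09375: bubbles mostly hidden
```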
3D Parallelism
For the largest training runs (thousands of GPUs), all three strategies are combined:
Typical configurations exploit the hardware topology:
- TP within a node (4-8 GPUs connected by NVLink).
- PP across nodes within a rack (fast interconnect).
- DP/FSDP across racks.
Example: training a 175B model on 1024 H100 GPUs might use TP=8 (within each node), PP=4 (across 4 nodes), and DP=32 (32 groups of 4 nodes).
| Parallelism | What is split | Communication | Typical Scale |
|---|---|---|---|
| Data (DDP/FSDP) | Batches (and optionally model state) | All-reduce / reduce-scatter | 8-1000s of GPUs |
| Tensor (TP) | Individual layers/matrices | All-reduce per layer | 2-8 GPUs (within node) |
| Pipeline (PP) | Groups of layers | Point-to-point activations | 2-16 stages |
Putting It All Together: A Training Recipe
Combining everything discussed, here is a summary of the standard configuration for modern LLM pretraining:
| Component | Standard Choice | Rationale |
|---|---|---|
| Optimizer | AdamW ($\beta_1 = 0.9$, $\beta_2 = 0.95$, $\epsilon = 10^{-8}$) | Decoupled weight decay; improves late-training stability |
| Weight Decay | 0.1 | Applied to all params except biases and LayerNorm |
| Peak Learning Rate | Scales with model size (e.g., 3e-4 for 7B, 1.5e-4 for 70B) | Smaller models tolerate higher LR |
| Warmup | 1-2% of total steps (1000-2000 steps) | Stabilizes Adam's variance estimate |
| LR Schedule | Cosine decay to 10% of peak | Smooth decay, well-studied empirically |
| Precision | BF16 mixed precision | Same range as FP32, no loss scaling needed |
| Gradient Clipping | Global norm = 1.0 | Prevents loss spikes from bad batches |
| Gradient Checkpointing | Every transformer block | Essential for large models |
| Batch Size | Ramp from small to large (4M tokens typical) | Large batch for throughput, small early for stability |
| Distributed Strategy | FSDP + TP (within node) | Balances memory and communication |
In the next post, we will explore Part 6: Inference Optimization --- the challenges of deploying trained models in production, including KV-cache mechanics, quantization from INT8 to INT4, speculative decoding for faster generation, and continuous batching with PagedAttention.
References
- Kingma, D. P. & Ba, J. (2015). Adam: A Method for Stochastic Optimization. ICLR 2015. arXiv:1412.6980.
- Loshchilov, I. & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019. arXiv:1711.05101.
- Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv:1604.06174.
- Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., & Catanzaro, B. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053.
- Zhao, Y., Gu, A., Varma, R., Luo, L., Huang, C., Xu, M., ... & Chintala, S. (2023). PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. VLDB 2023. arXiv:2304.11277.
- Touvron, H., Lavril, T., Izacard, G., et al. (2023). LLaMA: Open and Efficient Foundation Language Models. arXiv:2302.13971.
- Touvron, H., Martin, L., Stone, K., et al. (2023). LLaMA 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288.
- Brown, T. B., Mann, B., Ryder, N., et al. (2020). Language Models are Few-Shot Learners. NeurIPS 2020. arXiv:2005.14165.
- Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. arXiv:2203.15556.
- Chen, X., Liang, C., Huang, D., et al. (2023). Symbolic Discovery of Optimization Algorithms (Lion). arXiv:2302.06675.
- Micikevicius, P., Narang, S., Alben, J., et al. (2018). Mixed Precision Training. ICLR 2018. arXiv:1710.03740.
- Hu, S., Tu, Y., Han, X., et al. (2024). MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies. arXiv:2404.06395.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.