Transformer Deep Dive: Part 5 - Training Improvements
Modern training techniques for LLMs - AdamW optimizer, learning rate schedules, mixed precision training (FP16/BF16), gradient checkpointing, and distributed training strategies.
Suchinthaka W.
January 19, 2025 · 7 min read
Training large language models requires carefully orchestrated techniques that address three fundamental challenges: optimization stability across billions of parameters, memory constraints on limited GPU resources, and computational efficiency across distributed systems.
The Training Loop
Training a transformer involves iteratively updating the parameters $\theta$ to minimize a loss function $\mathcal{L}$:

$$\theta_{t+1} = \text{Update}\left(\theta_t,\; \nabla_\theta \mathcal{L}(\theta_t),\; \eta_t\right)$$

where the specific form of Update(·) defines the optimizer, $\eta_t$ is the learning rate (often scheduled), and the computation may be distributed across devices and carried out in reduced precision.
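Concretely, a single iteration looks like this minimal PyTorch sketch (the model and dataloader are assumed to exist; the optimizer stands in for Update(·)):

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # any optimizer implements Update(·)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)   # forward pass
    loss.backward()       # compute gradients
    optimizer.step()      # apply Update(·) with the current learning rate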
Optimizers
Adam (Adaptive Moment Estimation)
Adam combines momentum with adaptive learning rates per parameter:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$\theta_{t+1} = \theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Typical values: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.
AdamW: Decoupled Weight Decay
A critical discovery: L2 regularization and weight decay are not equivalent in Adam!
The Problem: In standard Adam with L2 regularization, the regularization term $\lambda\theta_t$ is added to the gradient, so it gets scaled by the adaptive learning rate:

$$g_t = \nabla_\theta \mathcal{L}(\theta_t) + \lambda\theta_t$$
Parameters with larger gradients receive less regularization—the opposite of what we want.
AdamW Solution: Apply weight decay directly to the parameters, outside the adaptive update:

$$\theta_{t+1} = \theta_t - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\theta_t\right)$$
# Simplified AdamW (m, v, and the step count t are per-parameter optimizer state
# carried over between iterations)
for param in model.parameters():
    if param.grad is None:
        continue
    # Standard Adam update: moving averages of the gradient and its square
    m = beta1 * m + (1 - beta1) * param.grad
    v = beta2 * v + (1 - beta2) * param.grad ** 2
    m_hat = m / (1 - beta1 ** t)   # bias correction
    v_hat = v / (1 - beta2 ** t)
    # AdamW: weight decay applied directly to the weights,
    # decoupled from the adaptive gradient scaling
    param.data -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * param.data)
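In practice you would use the built-in torch.optim.AdamW rather than hand-rolling the step; the hyperparameters shown here are illustrative, matching the recipe summarized at the end of this post:

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,               # peak learning rate (model-size dependent)
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.1,
)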
Optimizer Comparison
| Optimizer | Optimizer state (per param) | Key Feature |
|-----------|----------------------------|-------------|
| SGD | 0 bytes | Simple, needs tuning |
| SGD + Momentum | 4 bytes | More stable |
| Adam | 8 bytes | Adaptive LR |
| AdamW | 8 bytes | Proper weight decay |
| Adafactor | 4 bytes* | Memory efficient |
| Lion | 4 bytes | Simpler, competitive |
*Adafactor uses factored second moments
Learning Rate Schedules
Warmup
Warmup is critical for transformer training. Start with a small learning rate and increase it linearly to the peak value:

$$\eta_t = \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}}, \qquad t \le T_{\text{warmup}}$$
Why Warmup?
- Adam's variance estimate is biased early in training
- Large initial gradients can destabilize training
- Especially important with Pre-LN transformers
Cosine Decay
After warmup, decay the learning rate following a cosine curve:

$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\pi \cdot \frac{t - T_{\text{warmup}}}{T_{\text{total}} - T_{\text{warmup}}}\right)\right)$$
Linear Decay
A simple linear decrease from the peak to a final value:

$$\eta_t = \eta_{\max} - \left(\eta_{\max} - \eta_{\min}\right) \cdot \frac{t - T_{\text{warmup}}}{T_{\text{total}} - T_{\text{warmup}}}$$
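As a sketch, the same idea expressed as a LambdaLR multiplier (warmup_steps and total_steps are assumed to be defined; the 10% floor is an assumption matching the "decay to 10% of max" convention below):

def linear_lr_lambda(step):
    # Linear warmup, then linear decay from 1.0x down to 0.1x of the peak LR
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max(0.1, 1.0 - 0.9 * progress)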
Common Configurations
| Model | Warmup | Decay | Final LR |
|-------|--------|-------|----------|
| GPT-3 | 375M tokens | Cosine | 10% of max |
| LLaMA | 2000 steps | Cosine | 10% of max |
| Chinchilla | 1500 steps | Cosine | 10% of max |
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        # Linear warmup from 0 to the peak learning rate
        if step < warmup_steps:
            return step / warmup_steps
        # Cosine decay from 1.0x of the peak down to 0.1x
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.1 + 0.45 * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)
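Attach it to the optimizer and advance it once per optimizer step (the step counts here are placeholders):

scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=2000, total_steps=100_000)

# inside the training loop, after each optimizer.step():
scheduler.step()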
Mixed Precision Training
The Precision Hierarchy
| Format | Bits | Range | Precision | Use Case |
|--------|------|-------|-----------|----------|
| FP32 | 32 | ±3.4e38 | High | Master weights |
| TF32 | 19 | ±3.4e38 | Medium | Tensor cores |
| BF16 | 16 | ±3.4e38 | Low | Training |
| FP16 | 16 | ±65504 | Medium | Training |
| FP8 | 8 | ±448 | Low | Inference |
BF16 vs FP16
FP16 (IEEE Half Precision):
- 1 sign, 5 exponent, 10 mantissa
- Higher precision, limited range
- Needs loss scaling to prevent overflow/underflow
BF16 (Brain Float):
- 1 sign, 8 exponent, 7 mantissa
- Lower precision, same range as FP32
- No loss scaling needed
FP32: [1][8 exponent bits][23 mantissa bits]
FP16: [1][5 exponent bits][10 mantissa bits]
BF16: [1][8 exponent bits][7 mantissa bits]
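A quick illustrative check of the range difference (the value is chosen purely to show the overflow behaviour): 70,000 fits in BF16 but exceeds FP16's maximum of 65,504.

import torch

x = torch.tensor(70_000.0)
print(x.to(torch.float16))    # inf      -> exceeds FP16's maximum of 65504
print(x.to(torch.bfloat16))   # 70144.0  -> in range, but only ~3 significant digits kept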
Mixed Precision Strategy
- Master weights in FP32
- Forward pass in FP16/BF16
- Backward pass in FP16/BF16
- Gradient accumulation in FP32
- Optimizer update in FP32
import torch
from torch.cuda.amp import autocast, GradScaler

use_bf16 = True                            # BF16 needs no loss scaling; FP16 does
scaler = GradScaler(enabled=not use_bf16)  # becomes a pass-through when disabled
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()
    # Forward pass in reduced precision
    with autocast(dtype=torch.bfloat16 if use_bf16 else torch.float16):
        loss = model(batch)
    # Backward pass; the scaler only rescales when FP16 is used
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Memory Savings
| Precision | Memory per Param | Relative |
|-----------|-----------------|----------|
| FP32 weights + FP32 optimizer | 16 bytes | 1× |
| FP16 weights + FP32 optimizer | 12 bytes | 0.75× |
| BF16 weights + FP32 optimizer | 12 bytes | 0.75× |
Gradient Checkpointing
The Memory Problem
During backpropagation we need the activations saved from the forward pass. For a model with $L$ layers, batch size $B$, sequence length $T$, and hidden size $d$, the stored activations grow roughly as:

$$\text{Activation memory} \;\propto\; L \cdot B \cdot T \cdot d$$

(with a per-layer constant that depends on the architecture and the attention implementation).
For a 70B model with 80 layers, 2K context, this can exceed 100GB!
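As a rough back-of-envelope (the batch size, hidden size, and per-layer tensor count below are illustrative assumptions, not measurements):

# Illustrative assumptions: 80 layers, batch 8, 2K context, d_model 8192,
# ~10 saved tensors per layer, BF16 activations (2 bytes per element)
layers, batch, seq_len, d_model = 80, 8, 2048, 8192
tensors_per_layer, bytes_per_elem = 10, 2
activation_bytes = layers * batch * seq_len * d_model * tensors_per_layer * bytes_per_elem
print(f"{activation_bytes / 1e9:.0f} GB")   # ≈ 215 GB without checkpointing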
Checkpointing Strategy
Trade compute for memory: Don't store all activations. Instead:
- Store activations at "checkpoint" layers only
- During backward pass, recompute activations between checkpoints
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # Activations inside each layer are discarded here and
            # recomputed during the backward pass
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x
Memory-Compute Tradeoff
| Strategy | Memory | Compute |
|----------|--------|---------|
| No checkpointing | O(L) | 1× |
| Every layer | O(1) | ~2× |
| Every √L layers | O(√L) | ~1.5× |
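A sketch of the "every √L layers" strategy, extending the class above (the segmenting scheme and class name are my own illustration, not a reference implementation):

import math
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class BlockCheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        # Checkpoint every ~sqrt(L) layers instead of every layer
        self.block_size = max(1, int(math.sqrt(len(self.layers))))

    def _run_block(self, x, start, end):
        # Plain forward through one segment; nothing inside is stored
        for layer in self.layers[start:end]:
            x = layer(x)
        return x

    def forward(self, x):
        for start in range(0, len(self.layers), self.block_size):
            end = min(start + self.block_size, len(self.layers))
            # Only segment-boundary activations are kept; each segment is
            # recomputed as a whole during the backward pass
            x = checkpoint.checkpoint(self._run_block, x, start, end,
                                      use_reentrant=False)
        return x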
Distributed Training
Data Parallelism (DP)
Simplest approach: replicate model on each GPU, split data.
GPU 0: Full model, Batch 0
GPU 1: Full model, Batch 1
GPU 2: Full model, Batch 2
GPU 3: Full model, Batch 3
↓
All-Reduce gradients
↓
Synchronized update
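For reference, a minimal PyTorch DDP setup looks like this (assuming one process per GPU launched with torchrun, which sets the LOCAL_RANK environment variable):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")          # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = model.cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])  # gradients are all-reduced across ranks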
Fully Sharded Data Parallelism (FSDP)
Shard model parameters, gradients, and optimizer states across GPUs:
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # parameters and activations in BF16
        reduce_dtype=torch.float32,   # gradient reduction in FP32
    ),
)
Memory per GPU (70B Model)
| Strategy | Params | Grads | Optimizer | Total |
|----------|--------|-------|-----------|-------|
| No parallelism | 280GB | 280GB | 560GB | 1.1TB |
| DP (8 GPUs) | 280GB | 280GB | 560GB | 1.1TB |
| FSDP (8 GPUs) | 35GB | 35GB | 70GB | 140GB |
Tensor Parallelism
Split individual layers across GPUs:
Linear Layer: Y = XW
GPU 0: Y_0 = X @ W_0 (first half of columns)
GPU 1: Y_1 = X @ W_1 (second half of columns)
↓
All-Gather
↓
Y = [Y_0, Y_1]
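A simplified sketch of the column-parallel linear layer above (the class name is mine; production implementations such as Megatron-LM also handle biases, initialization, and differentiable collectives):

import torch
import torch.distributed as dist
import torch.nn as nn

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        # Each rank holds one column shard W_i of the full weight matrix
        self.local = nn.Linear(in_features, out_features // world_size, bias=False)

    def forward(self, x):
        y_local = self.local(x)                                 # Y_i = X @ W_i
        shards = [torch.empty_like(y_local) for _ in range(dist.get_world_size())]
        # Note: dist.all_gather is not differentiable; real training code uses a
        # differentiable all-gather (e.g. a custom autograd function)
        dist.all_gather(shards, y_local)
        return torch.cat(shards, dim=-1)                        # Y = [Y_0, Y_1, ...]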
Pipeline Parallelism
Split layers across GPUs, process micro-batches in pipeline:
Time →  T0  T1  T2  T3  T4  T5  T6  T7  T8  T9  T10 T11
GPU 0:  F0  F1  F2                          B2  B1  B0
GPU 1:      F0  F1  F2                  B2  B1  B0
GPU 2:          F0  F1  F2          B2  B1  B0
GPU 3:              F0  F1  F2  B2  B1  B0
F = Forward, B = Backward (per micro-batch); the backward pass flows through the stages in reverse, and the idle slots are the pipeline "bubble".
3D Parallelism
Combine all strategies for maximum scale:
Example: 1024 GPUs = 128 DP × 4 TP × 2 PP
Training Recipe Summary
| Component | Recommendation |
|-----------|---------------|
| Optimizer | AdamW (β₁ = 0.9, β₂ = 0.95) |
| Weight Decay | 0.1 |
| Warmup | 1-2% of training |
| LR Schedule | Cosine decay to 10% |
| Precision | BF16 mixed precision |
| Gradient Clipping | 1.0 |
| Batch Size | As large as memory allows |
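Putting the recipe together in one illustrative sketch (the learning rate and step counts are placeholders that depend on model scale):

import torch
from torch.cuda.amp import autocast

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.95), weight_decay=0.1)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=2000,
                                            total_steps=100_000)

for batch in dataloader:
    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):                         # BF16 mixed precision
        loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)      # clip gradients at 1.0
    optimizer.step()
    scheduler.step()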
In the next post, we'll explore Part 6: Inference Optimization - KV-cache, quantization, speculative decoding, and continuous batching for production deployment.