Transformer Deep Dive: Part 2 - Architecture Changes
How modern LLMs evolved from the original Transformer - decoder-only architecture, Pre-Layer Normalization, and RMSNorm. The fundamental architectural shifts that power GPT, LLaMA, and Mistral.
Suchinthaka W.
January 16, 2025 · 5 min read
The Transformer architecture from 2017 revolutionized NLP. However, modern LLMs like GPT-4, LLaMA 3, and Mistral differ significantly from the original design. This post explores three critical architectural innovations.
1. Decoder-Only Architecture
Historical Context
The original Transformer used an encoder-decoder architecture for sequence-to-sequence tasks like translation. But a shift happened:
2017: Transformer (Enc-Dec)
2018: BERT (Encoder-only), GPT-1 (Decoder-only)
2019: GPT-2
2020: GPT-3
2023+: LLaMA, Mistral, Claude, GPT-4 (All Decoder-only)
Three Architecture Families
| Architecture | Attention Type | Primary Task | Examples |
|--------------|----------------|--------------|----------|
| Encoder-Decoder | Bidirectional + Cross | Seq2Seq | T5, BART |
| Encoder-Only | Bidirectional | Classification | BERT, RoBERTa |
| Decoder-Only | Causal | Generation | GPT, LLaMA |
Why Decoder-Only Won
Decoder-only architectures became dominant because they offer a unified objective (next-token prediction) that scales remarkably well:
- Simplicity: One model architecture, one training objective
- Scalability: Easier to scale parameters and data
- Unified Interface: Same model handles all tasks via prompting
- Emergent Abilities: In-context learning emerges at scale
- Efficient Inference: KV-cache works naturally
The Unified Objective
The decoder-only model optimizes a single objective, next-token prediction:

$$\mathcal{L} = -\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$
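In code, this is just cross-entropy between each position's prediction and the token that actually follows it. A minimal sketch (the tensor names and shapes here are illustrative, not from any particular model):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: logits from a decoder-only model and the input token ids.
batch, seq_len, vocab_size = 2, 16, 32000
logits = torch.randn(batch, seq_len, vocab_size)          # model output
tokens = torch.randint(0, vocab_size, (batch, seq_len))   # input ids

# Shift by one: position t predicts token t+1, so drop the last logit and the first token.
pred = logits[:, :-1, :].reshape(-1, vocab_size)
target = tokens[:, 1:].reshape(-1)

# Average negative log-likelihood of the next token over all positions.
loss = F.cross_entropy(pred, target)
print(loss.item())
```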
Causal Masking
Standard attention allows each position to attend to all others. Causal attention adds a mask that sets attention scores for future positions to $-\infty$ before the softmax:
```python
import torch

def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create a causal attention mask: 0 on and below the diagonal, -inf above it."""
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask
```
Read as a visibility pattern (1 = can attend, 0 = masked out by $-\infty$), the mask looks like:
[1, 0, 0, 0] (position 0 sees only itself)
[1, 1, 0, 0] (position 1 sees 0 and itself)
[1, 1, 1, 0] (position 2 sees 0, 1, and itself)
[1, 1, 1, 1] (position 3 sees all previous)
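To see the mask in action, here is a short sketch (reusing the create_causal_mask function above) that adds it to raw attention scores before the softmax, which removes all weight on future positions:

```python
import torch
import torch.nn.functional as F

seq_len, d_k = 4, 8
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)

# Raw scaled dot-product scores, plus the causal mask (-inf above the diagonal).
scores = q @ k.T / d_k ** 0.5
scores = scores + create_causal_mask(seq_len, device=scores.device)

# Softmax turns -inf into exactly zero attention weight.
weights = F.softmax(scores, dim=-1)
print(weights)  # the upper triangle is all zeros
```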
2. Pre-Layer Normalization
The Problem with Post-LN
The original Transformer applied Layer Normalization after each sub-layer's residual connection (Post-LN). This arrangement becomes increasingly unstable to train as models get deeper.
Post-LN vs Pre-LN
Post-LN (Original):
x → SubLayer → Add(x, output) → LayerNorm → output
Pre-LN (Modern):
x → LayerNorm → SubLayer → Add(x, output) → output
The key difference: in Pre-LN, the residual stream runs from input to output without passing through any normalization or non-linear transformation, so gradients have a clean identity path backward through the network.
Mathematical Formulation
Post-LN:

$$x_{l+1} = \text{LayerNorm}\big(x_l + \text{SubLayer}(x_l)\big)$$

Pre-LN:

$$x_{l+1} = x_l + \text{SubLayer}\big(\text{LayerNorm}(x_l)\big)$$
Gradient Flow
Pre-LN provides much more stable gradient flow:
| Property | Post-LN | Pre-LN |
|----------|---------|--------|
| Warmup required | Essential (long) | Optional (short) |
| Max stable learning rate | Lower | Higher |
| Training stability | Fragile | Robust |
| Final performance | Slightly better* | Comparable |
*When training succeeds
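One rough way to see the difference is to stack a few residual blocks in each style and compare the gradient norm that reaches the input. The toy sketch below (plain linear layers standing in for attention/FFN, with illustrative depth and width) typically shows far less gradient surviving the Post-LN stack, because every step on its main path passes through a LayerNorm, while the Pre-LN identity path preserves it:

```python
import torch
import torch.nn as nn

def input_grad_norm(style: str, depth: int = 24, d_model: int = 64) -> float:
    """Stack `depth` residual blocks and return the gradient norm at the input."""
    torch.manual_seed(0)
    layers = [nn.Linear(d_model, d_model) for _ in range(depth)]
    norms = [nn.LayerNorm(d_model) for _ in range(depth)]
    x = torch.randn(8, d_model, requires_grad=True)
    h = x
    for layer, norm in zip(layers, norms):
        if style == "post_ln":
            h = norm(h + layer(h))   # normalize AFTER the residual add
        else:
            h = h + layer(norm(h))   # normalize BEFORE the sub-layer (Pre-LN)
    h.sum().backward()
    return x.grad.norm().item()

print("Post-LN:", input_grad_norm("post_ln"))
print("Pre-LN: ", input_grad_norm("pre_ln"))
```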
Implementation
```python
import torch
import torch.nn as nn

class PreLNTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, mask=None):
        # Pre-LN: normalize BEFORE attention, keep the residual path clean
        normed = self.ln1(x)
        attn_out, _ = self.attention(normed, normed, normed, attn_mask=mask)
        x = x + attn_out

        # Pre-LN: normalize BEFORE the FFN as well
        x = x + self.ffn(self.ln2(x))
        return x
```
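A quick usage check (note that nn.MultiheadAttention defaults to sequence-first inputs of shape (seq_len, batch, d_model); the sizes below are illustrative):

```python
block = PreLNTransformerBlock(d_model=512, n_heads=8, d_ff=2048)
x = torch.randn(16, 4, 512)                     # (seq_len, batch, d_model)
mask = create_causal_mask(16, device=x.device)  # reuse the mask from section 1
out = block(x, mask=mask)
print(out.shape)                                # torch.Size([16, 4, 512])
```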
Final LayerNorm
Most modern Pre-LN models apply one final LayerNorm to the last hidden state before the output projection. This is essential because, with Pre-LN, the residual stream itself never passes through a normalization layer on its way to the logits.
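A minimal sketch of where that final norm sits in a decoder-only stack, reusing the PreLNTransformerBlock and create_causal_mask defined above (the TinyDecoder class and its hyperparameters are illustrative, not any specific model):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, d_ff: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList(
            [PreLNTransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.final_ln = nn.LayerNorm(d_model)       # the extra, final normalization
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embed(tokens).transpose(0, 1)      # to (seq_len, batch, d_model)
        mask = create_causal_mask(x.size(0), device=x.device)
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.final_ln(x)                        # normalize the residual stream once
        return self.lm_head(x)                      # then project to vocabulary logits
```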
3. RMSNorm
LayerNorm Revisited
Standard Layer Normalization:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where the statistics are computed over the feature dimension:

$$\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2$$
The RMSNorm Simplification
Root Mean Square Layer Normalization removes the mean-centering step:

$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\text{RMS}(x)}$$

where:

$$\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$
Key Insight: RMSNorm hypothesizes that re-centering (mean subtraction) provides negligible benefit, while re-scaling is what matters.
Why Mean Subtraction May Be Unnecessary
The relationship between the two statistics is:

$$\text{RMS}(x)^2 = \sigma^2 + \mu^2$$

For activations with $\mu \approx 0$, which is common inside deep networks, $\text{RMS}(x) \approx \sigma$, so RMSNorm behaves almost identically to LayerNorm without the $\beta$ bias.
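A quick numerical sanity check of that identity, and of the near-equivalence for roughly zero-mean activations, using plain tensor ops:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096)                  # roughly zero-mean activations

mu = x.mean()
var = x.var(unbiased=False)
rms = x.pow(2).mean().sqrt()

# Identity: RMS(x)^2 = sigma^2 + mu^2
print(rms ** 2, var + mu ** 2)         # the two values match

# With mu ~ 0, dividing by RMS is almost the same as standardizing by sigma
print((x / rms - (x - mu) / var.sqrt()).abs().max())  # tiny difference
```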
Computational Savings
| Operation | LayerNorm | RMSNorm |
|-----------|-----------|---------|
| Mean computation | Required | Not needed |
| Variance computation | Requires the mean | Direct from the sum of squares |
| Parameters | γ and β | γ only |
| Relative speed | Baseline | ~10-15% faster |
Implementation
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))  # gamma only, no beta
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS: square root of the mean of squares over the feature dimension
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```
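A quick check that the module rescales each token vector to unit RMS (the gain γ is initialized to ones, so it is a no-op at first):

```python
norm = RMSNorm(d_model=512)
x = torch.randn(4, 16, 512) * 3.0            # deliberately badly scaled input
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, 0])    # per-token RMS of the output is ~1
```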
Who Uses RMSNorm?
| Model | Normalization |
|-------|---------------|
| GPT-2 / GPT-3 | LayerNorm |
| LLaMA | RMSNorm |
| LLaMA 2/3 | RMSNorm |
| Mistral | RMSNorm |
| Gemma | RMSNorm |
Summary: Original vs Modern
| Component | Original (2017) | Modern LLMs |
|-----------|-----------------|-------------|
| Architecture | Encoder-Decoder | Decoder-Only |
| LayerNorm position | Post-LN | Pre-LN |
| Normalization | LayerNorm | RMSNorm |
| Bias terms | Yes | Often removed |
In the next post, we'll explore Part 3: Attention Modifications - from positional encoding evolution to RoPE, Multi-Query Attention, Grouped Query Attention, and FlashAttention.