transformers · attention · deep-learning · llm · architecture

Transformer Deep Dive: Part 2 - Architecture Changes

How modern LLMs evolved from the original Transformer - decoder-only architecture, Pre-Layer Normalization, and RMSNorm. The fundamental architectural shifts that power GPT, LLaMA, and Mistral.


Suchinthaka W.

January 16, 2025 · 5 min read

The Transformer architecture from 2017 revolutionized NLP. However, modern LLMs like GPT-4, LLaMA 3, and Mistral differ significantly from the original design. This post explores three critical architectural innovations.

1. Decoder-Only Architecture

Historical Context

The original Transformer used an encoder-decoder architecture for sequence-to-sequence tasks like translation. But a shift happened:

2017: Transformer (Enc-Dec)
2018: BERT (Encoder-only), GPT-1 (Decoder-only)
2019: GPT-2
2020: GPT-3
2023+: LLaMA, Mistral, Claude, GPT-4 (All Decoder-only)

Three Architecture Families

| Architecture | Attention Type | Primary Task | Examples |
|--------------|----------------|--------------|----------|
| Encoder-Decoder | Bidirectional + Cross | Seq2Seq | T5, BART |
| Encoder-Only | Bidirectional | Classification | BERT, RoBERTa |
| Decoder-Only | Causal | Generation | GPT, LLaMA |

Why Decoder-Only Won

Decoder-only architectures became dominant because they offer a unified objective (next-token prediction) that scales remarkably well:

  1. Simplicity: One model architecture, one training objective
  2. Scalability: Easier to scale parameters and data
  3. Unified Interface: Same model handles all tasks via prompting
  4. Emergent Abilities: In-context learning emerges at scale
  5. Efficient Inference: KV-cache works naturally

The Unified Objective

The decoder-only model optimizes next token prediction:

\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t \mid x_1, x_2, \ldots, x_{t-1}; \theta)
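As a minimal sketch of this objective in PyTorch (assuming logits of shape [batch, seq_len, vocab_size] from some hypothetical model), the targets are simply the input tokens shifted left by one position:

import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Average cross-entropy over next-token predictions.

    logits:    [batch, seq_len, vocab_size]
    input_ids: [batch, seq_len]
    """
    # Predict token t+1 from positions <= t: drop the last logit, drop the first target.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_targets = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
    )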

Causal Masking

Standard attention allows each position to attend to all others. Causal attention adds a mask $M$ where future positions are set to $-\infty$:

\text{CausalAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V

import torch

def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create an additive causal attention mask (0 = attend, -inf = masked)."""
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

The resulting visibility pattern (1 = can attend, 0 = masked; the 0 entries correspond to the $-\infty$ entries in the additive mask above):

[1, 0, 0, 0]    (position 0 sees only itself)
[1, 1, 0, 0]    (position 1 sees 0 and itself)
[1, 1, 1, 0]    (position 2 sees 0, 1, and itself)
[1, 1, 1, 1]    (position 3 sees all previous)
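To sketch how the additive mask is used (reusing create_causal_mask from above), the $-\infty$ entries drive the corresponding softmax weights to zero, so each position only mixes information from itself and earlier positions:

import torch
import torch.nn.functional as F

def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention with a causal mask; q, k, v: [batch, seq_len, d_k]."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5       # [batch, seq_len, seq_len]
    scores = scores + create_causal_mask(q.size(1), q.device)
    weights = F.softmax(scores, dim=-1)                 # future positions get weight 0
    return weights @ v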

2. Pre-Layer Normalization

The Problem with Post-LN

The original transformer used Post-Layer Normalization:

x' = \text{LayerNorm}(x + \text{SubLayer}(x))

Because every residual connection passes through a LayerNorm, gradient scales become harder to control as models get deeper, and Post-LN training typically needs long learning-rate warmup to avoid divergence.

Post-LN vs Pre-LN

Post-LN (Original):

x → SubLayer → Add(x, output) → LayerNorm → output

Pre-LN (Modern):

x → LayerNorm → SubLayer → Add(x, output) → output

The key difference: in Pre-LN, the residual connection creates a direct path from input to output that never passes through a LayerNorm or a sublayer.

Mathematical Formulation

Post-LN:

h' = \text{LayerNorm}(x + \text{Attn}(x))
x' = \text{LayerNorm}(h' + \text{FFN}(h'))

Pre-LN:

h' = x + \text{Attn}(\text{LayerNorm}(x))
x' = h' + \text{FFN}(\text{LayerNorm}(h'))

Gradient Flow

Pre-LN provides much more stable gradient flow:

| Property | Post-LN | Pre-LN |
|----------|---------|--------|
| Warmup required | Essential (long) | Optional (short) |
| Max stable learning rate | Lower | Higher |
| Training stability | Fragile | Robust |
| Final performance | Slightly better* | Comparable |

*When training succeeds
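One way to see why (a sketch, following the usual Pre-LN analysis and writing each block's sublayer as a single function $F_\ell$): the Pre-LN update $x_\ell = x_{\ell-1} + F_\ell(\text{LayerNorm}(x_{\ell-1}))$ unrolls to

x_L = x_0 + \sum_{\ell=1}^{L} F_\ell(\text{LayerNorm}(x_{\ell-1}))

so the gradient with respect to the input contains an identity term:

\frac{\partial \mathcal{L}}{\partial x_0} = \frac{\partial \mathcal{L}}{\partial x_L}\left(I + \sum_{\ell=1}^{L} \frac{\partial F_\ell(\text{LayerNorm}(x_{\ell-1}))}{\partial x_0}\right)

That identity term gives gradients a depth-independent path back to the embeddings. In Post-LN, every LayerNorm sits on the residual path itself, so no such clean term survives and deep stacks need careful warmup.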

Implementation

import torch
from torch import nn

class PreLNTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, mask=None):
        # Pre-LN: normalize BEFORE attention
        normed = self.ln1(x)
        attn_out, _ = self.attention(normed, normed, normed, attn_mask=mask)
        x = x + attn_out  # Clean residual path

        # Pre-LN: normalize BEFORE FFN
        x = x + self.ffn(self.ln2(x))
        return x
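A quick usage sketch (note that nn.MultiheadAttention defaults to [seq_len, batch, d_model] ordering unless batch_first=True is passed; the mask helper from earlier is reused):

import torch

block = PreLNTransformerBlock(d_model=512, n_heads=8, d_ff=2048)
x = torch.randn(128, 4, 512)                   # [seq_len, batch, d_model]
mask = create_causal_mask(128, x.device)       # additive causal mask from earlier
out = block(x, mask=mask)
print(out.shape)                               # torch.Size([128, 4, 512])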

Final LayerNorm

Most modern Pre-LN models add a final LayerNorm before output:

y = W_{\text{out}} \cdot \text{LayerNorm}(x_L)

This is needed because, unlike in Post-LN, the Pre-LN residual stream is never normalized on its way to the output projection.
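A minimal sketch of this output head (the vocabulary projection lm_head here is a hypothetical, untied linear layer):

import torch
from torch import nn

d_model, vocab_size = 512, 32000
final_norm = nn.LayerNorm(d_model)             # normalizes the residual stream once, at the end
lm_head = nn.Linear(d_model, vocab_size, bias=False)

x_L = torch.randn(4, 128, d_model)             # [batch, seq_len, d_model] after L blocks
logits = lm_head(final_norm(x_L))              # [batch, seq_len, vocab_size]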

3. RMSNorm

LayerNorm Revisited

Standard Layer Normalization:

\text{LayerNorm}(x) = \frac{x - \mu}{\sigma} \odot \gamma + \beta

where:

\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2 + \epsilon}

The RMSNorm Simplification

Root Mean Square Layer Normalization removes mean-centering:

\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \odot \gamma

where:

\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}

Key Insight: RMSNorm hypothesizes that re-centering (mean subtraction) provides negligible benefit, while re-scaling is what matters.

Why Mean Subtraction May Be Unnecessary

The relationship:

\text{RMS}(x)^2 = \mu^2 + \sigma^2

For activations with $\mu \approx 0$: $\text{RMS}(x) \approx \sigma$, so the two normalizations behave almost identically in practice.
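A quick numeric check of this relationship (using the biased variance, i.e. dividing by $d$ rather than $d-1$):

import torch

x = torch.randn(4096)
mu = x.mean()
var = x.var(unbiased=False)                    # biased variance: divide by d
rms_sq = x.pow(2).mean()

print(torch.allclose(rms_sq, mu**2 + var))     # True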

Computational Savings

| Operation | LayerNorm | RMSNorm |
|-----------|-----------|---------|
| Mean computation | Required | Not needed |
| Variance computation | Requires mean | Direct from sum of squares |
| Parameters | γ and β | γ only |
| Compute cost | Baseline | ~10-15% faster |

Implementation

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS: sqrt of mean of squares
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
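A short usage sketch, including a check of the "γ only" point from the table above:

import torch
from torch import nn

d_model = 512
rms_norm = RMSNorm(d_model)
layer_norm = nn.LayerNorm(d_model)

x = torch.randn(4, 128, d_model)
print(rms_norm(x).shape)                                # torch.Size([4, 128, 512])

# RMSNorm carries only gamma; LayerNorm carries gamma and beta.
print(sum(p.numel() for p in rms_norm.parameters()))    # 512
print(sum(p.numel() for p in layer_norm.parameters()))  # 1024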

Who Uses RMSNorm?

| Model | Normalization |
|-------|---------------|
| GPT-2/3 | LayerNorm |
| LLaMA 1/2/3 | RMSNorm |
| Mistral | RMSNorm |
| Gemma | RMSNorm |

Summary: Original vs Modern

| Component | Original (2017) | Modern LLMs |
|-----------|-----------------|-------------|
| Architecture | Encoder-Decoder | Decoder-only |
| LayerNorm position | Post-LN | Pre-LN |
| Normalization | LayerNorm | RMSNorm |
| Bias terms | Yes | Often removed |


In the next post, we'll explore Part 3: Attention Modifications - from positional encoding evolution to RoPE, Multi-Query Attention, Grouped Query Attention, and FlashAttention.
