In Part 1, we revisited the original Transformer architecture from Vaswani et al. (2017). If you dropped that 2017 model into a modern training pipeline, it would fail spectacularly -- training would diverge, gradients would explode, and you would spend weeks debugging learning rate warmup schedules. The architecture that powers GPT-4, LLaMA 3, and Mistral has diverged from the original design in several fundamental ways, each motivated by concrete failure modes discovered through years of scaling experiments.

This post examines three critical architectural shifts: the move to decoder-only models, the repositioning of layer normalization, and the simplification of normalization itself. These are not incremental improvements. They represent hard-won lessons about what actually matters when you scale transformers to hundreds of billions of parameters.

1. The Rise of Decoder-Only Architecture

A Brief History

The original Transformer was an encoder-decoder model designed for machine translation. The encoder processes the source sentence bidirectionally, and the decoder generates the target sentence autoregressively, attending to the encoder output via cross-attention. This was a natural architecture for seq2seq tasks, but the field quickly discovered that simpler variants could be equally powerful.

In 2018, two competing paradigms emerged simultaneously. BERT (Devlin et al., 2018) took the encoder half, trained it with a masked language modeling objective, and achieved state-of-the-art results on classification and understanding tasks. GPT (Radford et al., 2018) took the decoder half, trained it with next-token prediction, and showed surprisingly strong zero-shot performance on diverse tasks.

By 2020, with GPT-3 demonstrating in-context learning at scale, the decoder-only paradigm had effectively won. Today, nearly every frontier model -- LLaMA, Mistral, Claude, GPT-4, Gemini, Qwen -- uses a decoder-only architecture.

Three Architecture Families

| Architecture | Attention Pattern | Training Objective | Notable Models |
|---|---|---|---|
| Encoder-Decoder | Bidirectional (enc) + causal (dec) + cross-attention | Span corruption, translation | T5, BART, mBART, Flan-T5 |
| Encoder-Only | Bidirectional | Masked language modeling | BERT, RoBERTa, DeBERTa |
| Decoder-Only | Causal (unidirectional) | Next-token prediction | GPT, LLaMA, Mistral, PaLM |

Why Decoder-Only Won

The dominance of decoder-only models was not preordained. Several concrete factors drove this convergence:

Unified training objective. Next-token prediction is the simplest possible objective. There are no masked spans to construct, no separate encoder and decoder losses to balance, and no architectural hyperparameters for cross-attention layers. The autoregressive objective naturally decomposes the joint probability of a sequence:

$$P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1})$$

This means the training loss is simply the negative log-likelihood:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{1:t-1})$$

Every token in every training sequence contributes a supervision signal. In BERT-style training, by contrast, only the roughly 15% of positions that are masked receive a loss, so most of the forward pass produces no learning signal.
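To make the objective concrete, the loss can be sketched in a few lines; this is a minimal version that ignores padding and sequence packing:

```python
import torch
import torch.nn.functional as F


def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Mean negative log-likelihood for next-token prediction.

    logits:    (batch, seq_len, vocab_size) model outputs
    input_ids: (batch, seq_len) token ids
    """
    # Position t predicts token t+1: drop the last logit, shift targets left.
    shift_logits = logits[:, :-1, :]
    shift_targets = input_ids[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_targets.reshape(-1),
    )
```

A handy sanity check: with uniform logits over a vocabulary of size V, this loss evaluates to log V.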

Scaling simplicity. With one objective and one architecture, the only decisions are model size, data, and compute. This aligns perfectly with the scaling laws discovered by Kaplan et al. (2020) and later refined by Hoffmann et al. (2022, "Chinchilla"). The research community converged on a simple recipe: take a decoder-only transformer, scale it up, and feed it more data.

Emergent in-context learning. Perhaps the most surprising property of large decoder-only models is in-context learning (ICL): the ability to perform new tasks by conditioning on a few examples in the prompt, without any gradient updates. This effectively turns a single model into a general-purpose task solver, eliminating the need for task-specific architectures.

KV-cache efficiency. During autoregressive generation, decoder-only models naturally support a key-value cache. At each timestep, we only need to compute the query for the new token and attend to the cached keys and values from all previous tokens. This makes generation $O(T)$ per token rather than $O(T^2)$. Encoder-decoder models require maintaining both an encoder KV-cache and a decoder KV-cache, with cross-attention adding complexity.
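A single decoding step with a cache can be sketched as follows (single head, projections omitted; the function name is illustrative):

```python
import torch


def cached_attention_step(
    q_new: torch.Tensor,    # (1, d) query for the newest token
    k_new: torch.Tensor,    # (1, d) key for the newest token
    v_new: torch.Tensor,    # (1, d) value for the newest token
    k_cache: torch.Tensor,  # (t, d) keys from previous steps
    v_cache: torch.Tensor,  # (t, d) values from previous steps
):
    """One autoregressive decoding step with a KV cache (single head).

    Appends the new key/value, then attends the single new query over
    all t+1 cached positions: O(t) work per token instead of
    recomputing the full (t+1) x (t+1) attention matrix.
    """
    k_cache = torch.cat([k_cache, k_new], dim=0)
    v_cache = torch.cat([v_cache, v_new], dim=0)
    scale = q_new.size(-1) ** -0.5
    attn = torch.softmax((q_new @ k_cache.T) * scale, dim=-1)  # (1, t+1)
    out = attn @ v_cache                                       # (1, d)
    return out, k_cache, v_cache
```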

Task unification via prompting. A decoder-only model can handle classification, summarization, translation, reasoning, and code generation -- all through different prompt formats. The model's input and output share the same vocabulary and representation space, eliminating the need for task-specific heads.

Causal Masking in Detail

The defining feature of a decoder-only model is the causal attention mask, which ensures that each token can only attend to itself and preceding tokens. This is implemented by adding a mask $M$ to the attention scores before softmax:

$$\text{CausalAttention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $M$ places $-\infty$ on every strictly upper-triangular entry:

$$M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$

After the softmax, positions with $-\infty$ scores become zero, effectively preventing information from flowing backward in the sequence. Here is a complete implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create a causal attention mask.

    Returns a (seq_len, seq_len) tensor where future positions
    are set to -inf and past/current positions are 0.
    """
    mask = torch.triu(
        torch.ones(seq_len, seq_len, device=device),
        diagonal=1
    )
    return mask.masked_fill(mask == 1, float('-inf'))


def causal_self_attention(
    x: torch.Tensor,
    W_q: nn.Linear,
    W_k: nn.Linear,
    W_v: nn.Linear,
    n_heads: int
) -> torch.Tensor:
    """Causal multi-head self-attention.

    Args:
        x: Input tensor of shape (batch, seq_len, d_model)
        W_q, W_k, W_v: Projection layers
        n_heads: Number of attention heads
    """
    B, T, C = x.shape
    head_dim = C // n_heads

    # Project to Q, K, V
    q = W_q(x).view(B, T, n_heads, head_dim).transpose(1, 2)
    k = W_k(x).view(B, T, n_heads, head_dim).transpose(1, 2)
    v = W_v(x).view(B, T, n_heads, head_dim).transpose(1, 2)

    # Scaled dot-product attention with causal mask
    scale = head_dim ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale

    # Apply causal mask
    mask = create_causal_mask(T, x.device)
    attn = attn + mask  # Broadcasting: (B, H, T, T) + (T, T)

    attn = F.softmax(attn, dim=-1)
    out = attn @ v  # (B, H, T, head_dim)

    # Recombine heads
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return out
```

The resulting attention pattern looks like this (1 = attends, 0 = masked):

|       | Pos 0 | Pos 1 | Pos 2 | Pos 3 |
|-------|-------|-------|-------|-------|
| Pos 0 | 1 | 0 | 0 | 0 |
| Pos 1 | 1 | 1 | 0 | 0 |
| Pos 2 | 1 | 1 | 1 | 0 |
| Pos 3 | 1 | 1 | 1 | 1 |

Prefix LM: A Hybrid Approach

It is worth noting that some models use a prefix LM pattern, where a prefix of the sequence uses bidirectional attention (no causal mask) and the remainder uses causal attention. This can be seen as a decoder-only model where the prompt portion gets bidirectional context. U-PaLM and some T5 variants explored this approach, though pure causal masking remains dominant.
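A prefix-LM mask can be built by starting from the causal pattern and unmasking the prefix block. `prefix_lm_mask` below is a hypothetical helper using the same additive-mask convention (0 = allowed, -inf = masked) as the causal mask above:

```python
import torch


def prefix_lm_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    """Additive attention mask for a prefix LM.

    Positions inside the prefix attend bidirectionally to each other;
    the remaining positions attend causally. Returns a
    (seq_len, seq_len) tensor with 0 where attention is allowed
    and -inf where it is masked.
    """
    # Start from the causal pattern: position i may attend to j <= i.
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Additionally allow all pairs within the prefix.
    allowed[:prefix_len, :prefix_len] = True
    mask = torch.zeros(seq_len, seq_len)
    mask[~allowed] = float('-inf')
    return mask
```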

2. Pre-Layer Normalization

The Training Instability Problem

The original Transformer placed layer normalization after the residual connection -- a design now called Post-Layer Normalization (Post-LN). This worked for the relatively small models of 2017, but as researchers tried to scale up, they hit a wall: training became increasingly unstable, requiring carefully tuned learning rate warmup schedules, and often diverging entirely for deeper models.

Xiong et al. (2020) provided a theoretical explanation. In Post-LN, the expected gradient norm at the output layer grows with depth, while the expected gradient at earlier layers can vanish. This creates a precarious optimization landscape that demands very careful warmup to avoid divergence.

Post-LN vs Pre-LN: Structural Comparison

The difference is a simple reordering of operations, but the consequences are profound.

Post-LN (Original Transformer):

$$h_l' = \text{LayerNorm}\big(x_l + \text{Attn}(x_l)\big)$$
$$x_{l+1} = \text{LayerNorm}\big(h_l' + \text{FFN}(h_l')\big)$$

Pre-LN (Modern):

$$h_l' = x_l + \text{Attn}\big(\text{LayerNorm}(x_l)\big)$$
$$x_{l+1} = h_l' + \text{FFN}\big(\text{LayerNorm}(h_l')\big)$$

(Figure: Pre-LN vs Post-LN architecture comparison. In Post-LN, normalization sits on the residual path, creating gradient bottlenecks. In Pre-LN, the residual stream flows unimpeded from input to output.)
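For contrast with the Pre-LN implementation given later, here is a sketch of a block in the original Post-LN ordering (illustrative only; the original model also applied dropout inside each sublayer):

```python
import torch
import torch.nn as nn


class PostLNBlock(nn.Module):
    """Transformer block in the original Post-LN ordering.

    LayerNorm is applied AFTER each residual sum, so it sits
    directly on the residual path.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.ln1(x + attn_out)     # normalize the residual sum
        x = self.ln2(x + self.ffn(x))  # and again after the FFN
        return x
```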

Why This Matters: Gradient Flow Analysis

The critical insight is about the residual stream. In Pre-LN, if we unroll the residual connections across $L$ layers, the output of the network can be written as:

$$x_L = x_0 + \sum_{l=1}^{L} F_l\big(\text{LayerNorm}(x_{l-1})\big)$$

where $F_l$ represents the sublayer function (attention or FFN) at layer $l$. This has a direct additive path from input $x_0$ to output $x_L$ -- the residual stream is never passed through a normalization layer.

Taking the gradient with respect to parameters $\theta_l$ in an early layer:

$$\frac{\partial \mathcal{L}}{\partial \theta_l} = \frac{\partial \mathcal{L}}{\partial x_L} \cdot \frac{\partial x_L}{\partial \theta_l}$$

Because $x_L = x_0 + \sum_l F_l(\cdot)$, the gradient flows directly from the loss back to any layer without being multiplicatively attenuated by intervening normalization layers. This is analogous to how ResNets solved the vanishing gradient problem in CNNs.

In Post-LN, by contrast, the normalization sits directly on the residual path. Each LayerNorm introduces a Jacobian that can attenuate or amplify gradients, and these effects compound across layers. Xiong et al. (2020) showed that the gradient norm at initialization follows:

$$\|\nabla_{x_l}\mathcal{L}\|_{\text{Post-LN}} \propto \frac{1}{\sqrt{l}} \quad \text{vs} \quad \|\nabla_{x_l}\mathcal{L}\|_{\text{Pre-LN}} \approx \text{const}$$

This explains why Post-LN requires extensive learning rate warmup (often thousands of steps) while Pre-LN can begin training with the full learning rate immediately.
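A minimal linear warmup schedule of the kind Post-LN training relies on might look like this (a sketch; production schedules typically add a decay phase, e.g. cosine, after warmup):

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linear learning-rate warmup.

    Ramps from ~0 to base_lr over warmup_steps, then holds.
    Post-LN training typically needs thousands of warmup steps;
    Pre-LN can often start at or near base_lr immediately.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr
```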

Practical Comparison

| Property | Post-LN | Pre-LN |
|---|---|---|
| Learning rate warmup | Essential (thousands of steps) | Minimal or unnecessary |
| Maximum stable learning rate | Smaller | Larger |
| Training stability for deep models | Fragile, prone to divergence | Robust |
| Final performance (when training succeeds) | Marginally better in some cases | Comparable |
| Used by | Original Transformer, early BERT | GPT-2/3, LLaMA, Mistral, PaLM |

The marginal performance advantage of Post-LN has motivated some recent work on stabilizing it (e.g., Admin initialization), but in practice, the stability advantages of Pre-LN have made it the universal default for large-scale training.

Pre-LN Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreLNTransformerBlock(nn.Module):
    """A single transformer block with Pre-Layer Normalization.

    This is the standard building block used in GPT-2, GPT-3,
    LLaMA, and most modern decoder-only LLMs.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        self.attention = nn.MultiheadAttention(
            d_model, n_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            nn.GELU(),
            nn.Linear(d_ff, d_model, bias=bias),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # ---- Attention sub-block (Pre-LN) ----
        # Normalize BEFORE the sublayer
        normed = self.ln1(x)
        attn_out, _ = self.attention(
            normed, normed, normed,
            attn_mask=attn_mask,
            need_weights=False
        )
        # Residual connection bypasses normalization
        x = x + self.dropout(attn_out)

        # ---- FFN sub-block (Pre-LN) ----
        x = x + self.ffn(self.ln2(x))
        return x


class PreLNTransformer(nn.Module):
    """Complete Pre-LN decoder-only transformer.

    Note the final LayerNorm after the last block -- this is
    essential because the residual stream is unnormalized.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        d_ff: int = 3072,
        max_seq_len: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            PreLNTransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Final LayerNorm -- critical for Pre-LN architecture
        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        B, T = input_ids.shape
        positions = torch.arange(T, device=input_ids.device)

        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.dropout(x)

        # Causal mask
        mask = create_causal_mask(T, input_ids.device)

        for block in self.blocks:
            x = block(x, attn_mask=mask)

        # Final normalization before output projection
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits
```

The Final LayerNorm

Notice the ln_final in the model above. In a Pre-LN architecture, the residual stream accumulates contributions from every layer without ever being normalized on the main path. By the time we reach the last layer, the activations can have grown substantially. The final LayerNorm brings the representation back to a normalized scale before projecting to vocabulary logits:

$$\text{logits} = W_{\text{vocab}} \cdot \text{LayerNorm}(x_L)$$

Omitting this final normalization typically leads to training instability or poor performance. Every major Pre-LN model (GPT-2, LLaMA, Mistral, PaLM) includes it.

3. RMSNorm: Simpler and Faster Normalization

LayerNorm Revisited

Standard Layer Normalization (Ba et al., 2016) computes both the mean and variance of the activations, then re-centers and re-scales:

$$\text{LayerNorm}(x) = \frac{x - \mu}{\sigma} \odot \gamma + \beta$$

where:

$$\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2 + \epsilon}$$

Here $\gamma$ and $\beta$ are learnable gain and bias parameters, each of dimension $d$. The computation involves two passes over the data (one for the mean, one for the variance), plus the subtraction and division.
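To make the formula concrete, here is a hand-rolled version checked against the definition (gain and bias omitted, i.e. $\gamma = 1$, $\beta = 0$, which matches a freshly initialized `nn.LayerNorm`):

```python
import torch
import torch.nn as nn


def layer_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm over the last dimension, written out from the formula."""
    mu = x.mean(dim=-1, keepdim=True)
    # Biased (population) variance, as in the definition above
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)
```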

The RMSNorm Simplification

Zhang and Sennrich (2019) proposed Root Mean Square Layer Normalization, which eliminates the mean computation entirely:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \odot \gamma$$

where:

$$\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$

Two things are different here. First, there is no mean subtraction -- the input is divided by its root mean square directly. Second, there is no bias parameter $\beta$, only a gain $\gamma$.

Why Removing Mean Subtraction Works

The key insight is a mathematical relationship between the RMS and the standard deviation. The variance can be decomposed as:

$$\sigma^2 = \mathbb{E}[x^2] - (\mathbb{E}[x])^2 = \text{RMS}(x)^2 - \mu^2$$

Therefore:

$$\text{RMS}(x)^2 = \mu^2 + \sigma^2$$

For neural network activations, especially in deeper layers with residual connections, the mean $\mu$ tends to be close to zero. When $\mu \approx 0$, we get $\text{RMS}(x) \approx \sigma$, and RMSNorm becomes approximately equivalent to LayerNorm (without the centering).

Zhang and Sennrich (2019) provided empirical evidence that the re-centering operation (mean subtraction) contributes negligibly to the success of LayerNorm, while the re-scaling operation (division by a measure of spread) is what actually stabilizes training. This is an elegant instance of removing unnecessary computation without sacrificing model quality.
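A quick numerical check of this relationship, using Gaussian noise as a stand-in for roughly zero-mean activations:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096)  # stand-in for roughly zero-mean activations

mu = x.mean()
sigma = x.std(unbiased=False)
rms = x.pow(2).mean().sqrt()

# RMS(x)^2 equals mu^2 + sigma^2 up to float error, and since mu
# is near zero here, rms lands within a fraction of a percent of sigma.
```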

Computational Savings

| Operation | LayerNorm | RMSNorm |
|---|---|---|
| Reduction passes | 2 (mean, then variance) | 1 (sum of squares) |
| Mean computation | Required | Not needed |
| Mean subtraction | Required | Not needed |
| Learnable parameters | $\gamma$ and $\beta$ ($2d$) | $\gamma$ only ($d$) |
| Wall-clock speedup | Baseline | ~10-15% faster |
| Memory for parameters | $2d$ floats | $d$ floats |

The 10-15% speedup may seem modest, but normalization is applied at every sublayer of every transformer block. In a 32-layer LLaMA model, that is 64 normalization operations per forward pass. At scale, this adds up to meaningful savings in both training and inference time.

Implementation

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Used by LLaMA, LLaMA 2, LLaMA 3, Mistral, Gemma,
    and most modern LLMs as a drop-in replacement for LayerNorm.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to float32 for numerical stability, then back
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```

A few implementation details are worth noting:

  1. torch.rsqrt computes $1/\sqrt{x}$ in a single fused operation, which is faster than computing the square root and then dividing.
  2. Float32 accumulation: The norm computation is done in float32 even if the input is in bfloat16. This prevents numerical issues when squaring small values. The result is cast back to the original dtype afterward.
  3. No bias parameter: There is no additive bias, which means the output is purely a scaled version of the input direction.
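These choices give RMSNorm a property worth verifying: it is invariant to rescaling its input (a positive constant cancels in the RMS denominator, up to epsilon), but, unlike LayerNorm, not to shifting it. A small check with a gain-free sketch:

```python
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Unit-gain RMSNorm (gamma = 1), just for illustrating invariances.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 16)

# Scale-invariant: the positive constant cancels in the RMS.
y_scaled = rms_norm(3.0 * x)
# NOT shift-invariant: RMSNorm never subtracts the mean, so adding
# a constant changes the output (LayerNorm would be unaffected).
y_shifted = rms_norm(x + 5.0)
```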

A LLaMA-Style Block with RMSNorm

Combining Pre-LN with RMSNorm gives us the standard building block of modern LLMs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LLaMABlock(nn.Module):
    """Transformer block following the LLaMA architecture.

    Uses Pre-LN with RMSNorm instead of LayerNorm.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

        self.attention = nn.MultiheadAttention(
            d_model, n_heads,
            bias=False,
            batch_first=True
        )
        # SwiGLU FFN (covered in Part 4)
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Pre-RMSNorm + Attention
        h = self.attn_norm(x)
        attn_out, _ = self.attention(h, h, h, attn_mask=mask)
        x = x + attn_out

        # Pre-RMSNorm + SwiGLU FFN
        h = self.ffn_norm(x)
        x = x + self.w2(F.silu(self.w1(h)) * self.w3(h))
        return x
```

Adoption Across Models

| Model | Year | Normalization | Position |
|---|---|---|---|
| Original Transformer | 2017 | LayerNorm | Post-LN |
| GPT-2 | 2019 | LayerNorm | Pre-LN |
| GPT-3 | 2020 | LayerNorm | Pre-LN |
| PaLM | 2022 | LayerNorm | Pre-LN (parallel) |
| LLaMA | 2023 | RMSNorm | Pre-LN |
| LLaMA 2 | 2023 | RMSNorm | Pre-LN |
| Mistral 7B | 2023 | RMSNorm | Pre-LN |
| Gemma | 2024 | RMSNorm | Pre-LN |
| LLaMA 3 | 2024 | RMSNorm | Pre-LN |

The trend is clear: RMSNorm with Pre-LN positioning has become the de facto standard.

A Note on QK-Norm

An emerging technique is QK-Norm, where an additional RMSNorm is applied to the query and key vectors before computing attention scores. This prevents attention logits from growing too large, which can cause issues with float16/bfloat16 precision:

$$\text{Attn}(Q, K, V) = \text{softmax}\!\left(\frac{\text{RMSNorm}(Q)\,\text{RMSNorm}(K)^\top}{\sqrt{d_k}}\right)V$$

Recent models such as Gemma 3 adopt QK-Norm for additional training stability, especially at very large scales.
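A minimal sketch of the idea: RMS-normalize queries and keys per head before the dot product, which bounds the attention logit magnitude. (The causal mask and the learnable per-dimension gains used in practice are omitted here; `qk_norm_attention` is an illustrative name.)

```python
import torch
import torch.nn.functional as F


def qk_norm_attention(
    q: torch.Tensor,  # (batch, heads, seq, head_dim)
    k: torch.Tensor,
    v: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Attention with QK-Norm (no mask, no learnable gains).

    After normalization every query/key row has unit RMS, so each
    scaled logit is bounded regardless of how large the raw
    activations grow -- helpful in float16/bfloat16.
    """
    q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps)
    k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps)
    scale = q.size(-1) ** -0.5
    attn = F.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v
```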

4. Removing Bias Terms

One additional change worth mentioning: most modern LLMs remove bias terms from linear layers throughout the model. The original Transformer used biases in the attention projections ($W_Q$, $W_K$, $W_V$, $W_O$), the FFN layers, and layer normalization.

Modern models like LLaMA set bias=False everywhere. The rationale is:

  1. Parameter efficiency: Bias terms add $d$ parameters per linear layer, which is negligible compared to the $d \times d$ weight matrices but adds implementation complexity.
  2. RMSNorm has no bias: Since RMSNorm already omits the bias term $\beta$, removing biases from linear layers is consistent.
  3. Empirical finding: Multiple ablation studies have shown no degradation from removing biases.
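The parameter-efficiency point is easy to quantify. For a hypothetical hidden size of 4096 (typical of 7B-class models), the bias adds only about 0.02% to a square projection:

```python
import torch.nn as nn


def param_count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


d = 4096  # hidden size of a 7B-class model, for scale
with_bias = param_count(nn.Linear(d, d, bias=True))
without_bias = param_count(nn.Linear(d, d, bias=False))

# The bias contributes d parameters against d*d weights: 1/d relative,
# i.e. roughly 0.02% at d = 4096.
overhead = (with_bias - without_bias) / without_bias
```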

Summary: Original Transformer vs Modern LLMs

| Component | Original Transformer (2017) | Modern LLMs (2023+) |
|---|---|---|
| Architecture | Encoder-Decoder | Decoder-Only |
| Training objective | Seq2seq cross-entropy | Next-token prediction |
| LayerNorm position | Post-LN | Pre-LN |
| Normalization | LayerNorm (with bias) | RMSNorm (no bias) |
| Bias terms | Yes (everywhere) | No (removed) |
| Positional encoding | Sinusoidal (additive) | RoPE (multiplicative) |
| FFN activation | ReLU | SwiGLU |
| Attention variant | Multi-Head (MHA) | Grouped-Query (GQA) |

Each of these changes is motivated by a specific failure mode or efficiency improvement discovered through scaling. The modern LLM is not a minor evolution of the original Transformer -- it is a substantially rearchitected system, rebuilt piece by piece as researchers discovered what breaks at scale.


In the next post, we will dive into Part 3: Attention Modifications -- the evolution of positional encoding from sinusoidal to RoPE, the KV-cache efficiency improvements of Multi-Query and Grouped-Query Attention, and how FlashAttention exploits GPU memory hierarchies to make attention both faster and more memory-efficient.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
  2. Radford, A., Narasimhan, K., Salimans, T., Sutskever, I. (2018). "Improving Language Understanding by Generative Pre-Training." OpenAI.
  3. Devlin, J., Chang, M.-W., Lee, K., Toutanova, K. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL 2019.
  4. Xiong, R., Yang, Y., He, D., et al. (2020). "On Layer Normalization in the Transformer Architecture." ICML 2020.
  5. Zhang, B. and Sennrich, R. (2019). "Root Mean Square Layer Normalization." NeurIPS 2019.
  6. Ba, J. L., Kiros, J. R., Hinton, G. E. (2016). "Layer Normalization." arXiv:1607.06450.
  7. Kaplan, J., McCandlish, S., Henighan, T., et al. (2020). "Scaling Laws for Neural Language Models." arXiv:2001.08361.
  8. Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). "Training Compute-Optimal Large Language Models." NeurIPS 2022.
  9. Touvron, H., Lavril, T., Izacard, G., et al. (2023). "LLaMA: Open and Efficient Foundation Language Models." arXiv:2302.13971.
  10. Touvron, H., Martin, L., Stone, K., et al. (2023). "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv:2307.09288.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
