In Part 2, we examined the macro-level architectural shifts that turned the original Transformer into a decoder-only model with Pre-LN and RMSNorm. Now we turn to the heart of the architecture: the attention mechanism itself. Since 2017, researchers have redesigned almost every aspect of attention -- how positions are encoded, how key-value heads are shared, how the computation maps to hardware, and how context windows are extended.

These modifications address four fundamental challenges:

  1. Position encoding: The original Transformer has no inherent notion of token order. Position must be explicitly injected, and the method of injection determines whether the model can generalize to unseen sequence lengths.
  2. KV-cache memory: During autoregressive inference, storing keys and values for all previous tokens across all layers and all heads becomes a major memory bottleneck.
  3. Memory bandwidth: Attention is memory-bound on modern GPUs -- the arithmetic intensity is too low relative to the data movement required.
  4. Quadratic complexity: The O(n^2) attention matrix becomes prohibitive for very long sequences.

Each challenge has spawned a family of solutions. This post covers the most impactful ones.

1. The Evolution of Positional Encoding

Why Position Matters

Self-attention is permutation-equivariant: if you shuffle the input tokens, the output is shuffled in the same way. This means a Transformer without position information treats "the cat sat on the mat" identically to "mat the on sat cat the." Position encoding breaks this symmetry by providing each token with information about where it sits in the sequence.

The evolution of positional encoding reflects a deepening understanding of what properties matter:

| Method | Year | Type | Extrapolation | Relative Position | Used By |
| --- | --- | --- | --- | --- | --- |
| Sinusoidal | 2017 | Additive, fixed | Limited | Indirect | Original Transformer |
| Learned | 2018 | Additive, learned | None | No | BERT, GPT-1/2 |
| ALiBi | 2022 | Attention bias | Good | Yes (linear) | BLOOM, MPT |
| RoPE | 2021/2023 | Multiplicative, fixed | Extensible | Yes (rotation) | LLaMA, Mistral, GPT-NeoX |

Sinusoidal Positional Encoding (2017)

Vaswani et al. used a fixed set of sinusoidal functions to encode position:

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)

These position vectors are added to the token embeddings before entering the first layer. Each dimension i oscillates at a different frequency, creating a unique "fingerprint" for each position. Vaswani et al. chose sinusoids because the encoding for position pos + k can be expressed as a linear function of the encoding at pos, theoretically enabling length extrapolation.

Limitations: In practice, sinusoidal encodings do not extrapolate well beyond training lengths. The position information also gets diluted as it passes through multiple attention and FFN layers, since it is only injected once at the input.
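The two formulas above translate directly into code. A minimal sketch (written here for illustration, not taken from any particular codebase):

```python
import torch

def sinusoidal_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Fixed sinusoidal position encodings of shape (max_len, d_model)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)           # even dimension indices
    angles = pos / (10000 ** (i / d_model))                        # (max_len, d_model // 2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angles)   # even dims: sine
    pe[:, 1::2] = torch.cos(angles)   # odd dims: cosine
    return pe
```

At position 0 the even dimensions are all 0 and the odd dimensions all 1, and each subsequent position shifts the phase of every frequency by a fixed amount.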

Learned Positional Embeddings (2018)

BERT and GPT-1/2 replaced the fixed sinusoidal encodings with a learnable embedding table -- a matrix P \in \mathbb{R}^{L_{\max} \times d}, where L_{\max} is the maximum sequence length. Position t simply looks up row P_t and adds it to the token embedding.

This is maximally flexible but has a hard limitation: the model cannot process any sequence longer than L_{\max}. There is no mechanism for extrapolation, and increasing the context window requires retraining or fine-tuning with a new, larger embedding table.

ALiBi: Attention with Linear Biases (2022)

Press et al. (2022) proposed a radically simple approach: do not encode position in the embeddings at all. Instead, add a position-dependent bias directly to the attention scores:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + m \cdot \text{Bias}\right)V

where \text{Bias}_{ij} = -|i - j| is a linear penalty based on distance, and m is a head-specific slope. Nearer tokens receive higher attention scores; distant tokens receive lower ones.

ALiBi demonstrates excellent length extrapolation -- models trained on short sequences can attend to much longer sequences at inference time. However, the linear decay may be too rigid for tasks requiring precise long-range dependencies. ALiBi is used in BLOOM and MPT but has largely been superseded by RoPE for most applications.
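The bias is cheap to construct. A sketch in PyTorch (the slope schedule 2^{-8(h+1)/n} follows the paper's recipe for head counts that are powers of two; the causal mask for future positions is assumed to be applied separately):

```python
import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    """ALiBi bias of shape (n_heads, seq_len, seq_len): slope_h * -(i - j) for j <= i."""
    # Head-specific slopes: geometric sequence 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)])
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    # -(i - j) = -|i - j| on the causal (j <= i) part; 0 elsewhere (masked by the causal mask)
    dist = -(i - j).clamp(min=0)
    return slopes.view(-1, 1, 1) * dist  # broadcast one slope per head
```

The resulting tensor is simply added to the pre-softmax attention scores, so no change to the embeddings or the attention kernel itself is needed.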

Rotary Position Embedding (RoPE)

RoPE, introduced by Su et al. (2021), is the positional encoding used by virtually every modern LLM: LLaMA, LLaMA 2/3, Mistral, Qwen, Gemma, and many others. It encodes position by rotating query and key vectors in a position-dependent manner, so that the dot product between a query at position m and a key at position n depends only on their relative distance (m - n).

The Core Idea

Instead of adding position information to embeddings, RoPE applies a rotation matrix R_\theta^{(m)} to the query and key vectors at position m:

\tilde{q}_m = R_\theta^{(m)} q_m, \qquad \tilde{k}_n = R_\theta^{(n)} k_n

The attention score between positions m and n is then:

\tilde{q}_m^\top \tilde{k}_n = q_m^\top (R_\theta^{(m)})^\top R_\theta^{(n)} k_n = q_m^\top R_\theta^{(n-m)} k_n

Because rotation matrices satisfy R^{-1} = R^\top and R^{(a)} R^{(b)} = R^{(a+b)}, the product depends only on the relative position (n - m), not the absolute positions.

Mathematical Derivation

Consider the head dimension d and pair the dimensions (2i, 2i+1) for i = 0, 1, \ldots, d/2 - 1. Each pair is rotated by an angle \theta_i \cdot m, where:

\theta_i = \frac{1}{10000^{2i/d}}

The rotation for dimension pair i at position m is a 2D rotation matrix:

R_i(m) = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}

The full rotation matrix R_\theta^{(m)} is a block-diagonal matrix with d/2 such blocks:

R_\theta^{(m)} = \begin{pmatrix} R_0(m) & & & \\ & R_1(m) & & \\ & & \ddots & \\ & & & R_{d/2-1}(m) \end{pmatrix}

The Complex Number Perspective

There is an elegant way to implement RoPE using complex arithmetic. We can view each pair (q_{2i}, q_{2i+1}) as a complex number q_i^{(\mathbb{C})} = q_{2i} + j \cdot q_{2i+1}. Rotating by angle m\theta_i is then simply multiplication by e^{jm\theta_i}:

\tilde{q}_i^{(\mathbb{C})} = q_i^{(\mathbb{C})} \cdot e^{jm\theta_i} = (q_{2i} + jq_{2i+1})(\cos m\theta_i + j\sin m\theta_i)

This is computationally efficient and can be implemented with element-wise complex multiplication.

Implementation

import torch
import torch.nn as nn
import math


def precompute_freqs_cis(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0
) -> torch.Tensor:
    """Precompute the complex exponentials for RoPE.

    Args:
        dim: Head dimension (must be even).
        max_seq_len: Maximum sequence length to precompute.
        theta: Base frequency (10000 in the original paper).

    Returns:
        Complex tensor of shape (max_seq_len, dim // 2).
    """
    # Frequencies for each dimension pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # Position indices
    t = torch.arange(max_seq_len, dtype=torch.float32)
    # Outer product: (seq_len,) x (dim//2,) -> (seq_len, dim//2)
    freqs = torch.outer(t, freqs)
    # Convert to complex exponentials: e^{i * theta}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE to query and key tensors.

    Args:
        xq: Query tensor of shape (batch, seq_len, n_heads, head_dim).
        xk: Key tensor of shape (batch, seq_len, n_kv_heads, head_dim).
        freqs_cis: Precomputed complex frequencies (seq_len, head_dim // 2).

    Returns:
        Rotated query and key tensors with same shapes as inputs.
    """
    # Reshape to pairs of 2: (..., head_dim) -> (..., head_dim//2, 2)
    # Then view as complex: (..., head_dim//2)
    xq_complex = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], -1, 2)
    )
    xk_complex = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], -1, 2)
    )

    # Reshape freqs_cis for broadcasting: (seq_len, 1, head_dim//2)
    freqs_cis = freqs_cis.unsqueeze(1)  # Add head dimension

    # Apply rotation via complex multiplication
    xq_rotated = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_rotated = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_rotated.type_as(xq), xk_rotated.type_as(xk)
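The relative-position property derived above can be verified numerically. A standalone sketch (independent of the implementation above; rope_rotate is a hypothetical helper for a single vector): scores for the same relative distance should match regardless of absolute offset.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate a single vector x of shape (head_dim,) to position `pos` via complex multiplication."""
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    angles = pos * freqs                                  # (d/2,) rotation angles
    xc = torch.view_as_complex(x.float().reshape(-1, 2))  # pair dims as complex numbers
    rotated = xc * torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(rotated).flatten()

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)

# Same relative distance (7), different absolute positions:
s1 = rope_rotate(q, 3) @ rope_rotate(k, 10)
s2 = rope_rotate(q, 50) @ rope_rotate(k, 57)
assert torch.allclose(s1, s2, atol=1e-4)  # depends only on n - m
```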

Key Properties of RoPE

  1. Relative position encoding: The dot product \tilde{q}_m^\top \tilde{k}_n depends only on (m - n), not on absolute positions. This is a theoretically desirable property.
  2. Natural distance decay: The lower-index dimensions (smaller i) have the highest frequencies and rotate fastest, causing the dot product to decay on average for distant token pairs. This provides a soft inductive bias toward local attention.
  3. No learnable parameters: RoPE is entirely deterministic -- there are no position embeddings to learn, which simplifies training and reduces parameter count.
  4. Applied at every layer: Unlike additive position encodings that are injected only at the input, RoPE is applied to the queries and keys at every attention layer. This ensures that position information is never diluted.

Context Length Extension with RoPE

One of RoPE's most valuable properties is that it can be extended to sequence lengths far beyond the training window. Several techniques have been developed:

| Technique | Mechanism | Key Idea |
| --- | --- | --- |
| Position Interpolation (PI) | m \to m \cdot \frac{L_{train}}{L_{target}} | Linearly interpolate positions to fit within the trained range |
| NTK-Aware Scaling | \theta \to \theta \cdot \alpha^{d/(d-2)} | Modify the base frequency to preserve high-frequency components |
| YaRN | Dynamic NTK + attention scaling | Combine frequency scaling with temperature adjustment |
| Code LLaMA | PI + fine-tuning | Interpolate, then fine-tune on long-context data |

Position Interpolation (Chen et al., 2023) is the simplest: instead of using position m directly, use m \cdot L_{train} / L_{target}. This maps all positions into the range the model was trained on, but compresses the resolution. It requires only a small amount of fine-tuning to work well.

NTK-Aware Scaling modifies the base frequency \theta rather than the positions, preserving the relative resolution of high-frequency components while extending low-frequency components. YaRN further refines this with dynamic per-dimension scaling.
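Both techniques reduce to simple transformations of the RoPE inputs. A hedged sketch (the \alpha exponent follows the common community formulation of NTK-aware scaling; exact recipes vary by implementation):

```python
import torch

def rope_freqs(dim: int, base: float = 10000.0) -> torch.Tensor:
    """Per-pair RoPE frequencies theta_i = base^(-2i/dim)."""
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def positions_pi(seq_len: int, l_train: int) -> torch.Tensor:
    """Position Interpolation: squeeze positions back into the trained range."""
    scale = min(1.0, l_train / seq_len)
    return torch.arange(seq_len).float() * scale

def rope_freqs_ntk(dim: int, alpha: float, base: float = 10000.0) -> torch.Tensor:
    """NTK-aware scaling: enlarge the base instead of shrinking positions."""
    return rope_freqs(dim, base * alpha ** (dim / (dim - 2)))

# PI maps position 8191 of an 8K context back inside a 4K training window:
assert positions_pi(8192, 4096)[-1] == 8191 * 0.5
# NTK scaling leaves the highest frequency unchanged but stretches the lowest:
f, f_ntk = rope_freqs(64), rope_freqs_ntk(64, alpha=8.0)
assert f_ntk[0] == f[0] and f_ntk[-1] < f[-1]
```

The final asserts illustrate the trade-off in the table: PI compresses every position uniformly, while NTK scaling spreads the change unevenly across frequencies.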

2. Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

The KV-Cache Bottleneck

During autoregressive generation, we maintain a KV-cache: the key and value tensors for all previous tokens, across all layers and all heads. This avoids recomputing them at each generation step, but the memory cost is substantial:

\text{KV Cache Size} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times T \times \text{bytes per element}

For a concrete example, consider LLaMA 2 70B with standard Multi-Head Attention (MHA):

  • L = 80 layers, n_{\text{heads}} = 64, d_{\text{head}} = 128, using bfloat16 (2 bytes)
  • At 32K context: 2 \times 80 \times 64 \times 128 \times 32768 \times 2 = 85.9 GB

That is 85.9 GB just for the KV-cache -- more than the model weights themselves. This becomes the primary bottleneck for serving long-context models, especially when batching multiple requests.
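The arithmetic above is easy to wrap in a helper (a back-of-the-envelope calculator for one sequence; multiply by batch size for serving):

```python
def kv_cache_bytes(
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    seq_len: int,
    bytes_per_elem: int = 2,  # bfloat16
) -> int:
    """KV-cache size in bytes: 2 (K and V) x layers x heads x dim x tokens x elem size."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# LLaMA 2 70B with full MHA at 32K context:
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=32768)
print(f"{mha / 1e9:.1f} GB")  # prints "85.9 GB"
```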

Multi-Query Attention (Shazeer, 2019)

Shazeer (2019) proposed a simple but effective solution: use a single key head and a single value head shared across all query heads.

In standard Multi-Head Attention:

  • Q: (B, T, n_h, d_h) -- each query head has its own projection
  • K: (B, T, n_h, d_h) -- each key head has its own projection
  • V: (B, T, n_h, d_h) -- each value head has its own projection

In Multi-Query Attention:

  • Q: (B, T, n_h, d_h) -- unchanged
  • K: (B, T, 1, d_h) -- single shared key head
  • V: (B, T, 1, d_h) -- single shared value head

The KV-cache is reduced by a factor of n_h (e.g., 64x for LLaMA 2 70B). For our earlier example, this drops the cache from 85.9 GB to about 1.3 GB.

However, MQA can degrade model quality because all query heads are forced to share the same key-value representation. This is a significant information bottleneck.

Grouped-Query Attention (Ainslie et al., 2023)

GQA strikes a middle ground: instead of sharing one KV head across all query heads, it shares one KV head across a group of query heads. With n_h query heads and n_{kv} KV heads, each KV head serves a group of n_h / n_{kv} query heads.

Figure: Comparison of Multi-Head Attention (MHA), Grouped-Query Attention (GQA), and Multi-Query Attention (MQA). MHA has one KV head per query head; GQA groups multiple query heads per KV head; MQA uses a single KV head for all queries.

| Method | Query Heads | KV Heads | KV Cache Reduction | Quality |
| --- | --- | --- | --- | --- |
| MHA | H | H | 1x (baseline) | Best |
| GQA | H | G (where 1 < G < H) | H/G x | Near-MHA |
| MQA | H | 1 | H x | Slight degradation |

Typical configurations in practice:

| Model | Query Heads | KV Heads | Group Size | KV Cache Reduction |
| --- | --- | --- | --- | --- |
| LLaMA 2 7B | 32 | 32 | 1 (MHA) | 1x |
| LLaMA 2 70B | 64 | 8 | 8 | 8x |
| Mistral 7B | 32 | 8 | 4 | 4x |
| LLaMA 3 8B | 32 | 8 | 4 | 4x |
| LLaMA 3 70B | 64 | 8 | 8 | 8x |

GQA Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention as used in LLaMA 2/3 and Mistral.

    Supports MHA (n_kv_heads == n_heads), GQA (1 < n_kv_heads < n_heads),
    and MQA (n_kv_heads == 1) as special cases.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        max_seq_len: int = 4096,
        rope_theta: float = 10000.0
    ):
        super().__init__()
        assert n_heads % n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({n_kv_heads})"

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads
        self.head_dim = d_model // n_heads

        # Separate projections for Q (full heads) and KV (reduced heads)
        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

        # Precompute RoPE frequencies
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(self.head_dim, max_seq_len, rope_theta),
            persistent=False
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        B, T, _ = x.shape

        # Project to queries, keys, values
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Apply RoPE to queries and keys
        freqs = self.freqs_cis[start_pos : start_pos + T]
        q, k = apply_rotary_emb(q, k, freqs)

        # Expand KV heads to match query heads
        # (B, T, n_kv_heads, head_dim) -> (B, T, n_heads, head_dim)
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        # Transpose for attention: (B, n_heads, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale

        if mask is not None:
            scores = scores + mask

        attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.type_as(q)

        # Apply attention to values
        output = attn_weights @ v  # (B, n_heads, T, head_dim)
        output = output.transpose(1, 2).contiguous().view(B, T, -1)

        return self.wo(output)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match the number of query heads.

        (B, T, n_kv_heads, head_dim) -> (B, T, n_heads, head_dim)
        """
        if self.n_groups == 1:
            return x  # MHA: no repetition needed
        return x.repeat_interleave(self.n_groups, dim=2)

A few implementation notes:

  1. repeat_interleave vs expand: Using repeat_interleave explicitly copies the data, which is simpler but uses more memory. An alternative is expand + reshape, which creates a view without copying. For training this difference matters; for inference with KV-cache, the repeated keys/values are computed once.

  2. Softmax in float32: The softmax is computed in float32 even if the inputs are bfloat16. This prevents numerical issues with large attention scores and is a standard practice in all production implementations.

  3. RoPE is applied before KV expansion: The rotation is applied to the n_{kv} KV heads, not the expanded n_h heads. This is both correct (each KV head has its own rotation) and efficient (fewer operations).
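The expand-based alternative from note 1 can be sketched as follows (same output as repeat_interleave, but the intermediate is a zero-copy view; only the final reshape may materialize data):

```python
import torch

def repeat_kv_expand(x: torch.Tensor, n_groups: int) -> torch.Tensor:
    """(B, T, n_kv_heads, head_dim) -> (B, T, n_kv_heads * n_groups, head_dim)."""
    if n_groups == 1:
        return x  # MHA: nothing to repeat
    B, T, n_kv, d = x.shape
    # Insert a broadcasted group axis as a view, then merge it into the head axis.
    return (
        x[:, :, :, None, :]
        .expand(B, T, n_kv, n_groups, d)
        .reshape(B, T, n_kv * n_groups, d)
    )

x = torch.randn(2, 5, 4, 8)
assert torch.equal(repeat_kv_expand(x, 3), x.repeat_interleave(3, dim=2))
```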

3. FlashAttention: IO-Aware Attention

The Memory Bandwidth Problem

Standard attention computes the full N×NN \times N attention matrix, writes it to GPU high-bandwidth memory (HBM), then reads it back to compute the output. On modern GPUs, this memory traffic is the bottleneck -- not the floating-point arithmetic.

To understand why, consider the GPU memory hierarchy:

| Memory Level | Capacity | Bandwidth | Latency |
| --- | --- | --- | --- |
| Registers | ~256 KB per SM | N/A | 0 cycles |
| SRAM (shared memory) | ~20 MB total | ~19 TB/s | ~5 cycles |
| HBM (global memory) | 40-80 GB | 2-3.35 TB/s | ~200 cycles |

The ratio of compute to memory bandwidth on an A100 is roughly 156:1 (312 TFLOPS / 2 TB/s). This means for every byte loaded from HBM, we can perform about 156 floating-point operations. Standard attention has an arithmetic intensity far below this threshold -- it is memory-bound.

The standard attention algorithm:

  1. Compute S = QK^\top / \sqrt{d_k} -- write the N \times N matrix to HBM
  2. Compute P = \text{softmax}(S) -- read from HBM, write N \times N back to HBM
  3. Compute O = PV -- read from HBM, write the output to HBM

For sequence length N = 4096 and d = 128, the attention matrix S alone requires 4096^2 \times 2 bytes = 32 MB in float16 -- per head, per layer. With 32 heads and 32 layers, that is 32 GB of intermediate memory just for one forward pass.

FlashAttention: Tiling and Online Softmax

Dao et al. (2022) introduced FlashAttention, which never materializes the full N \times N attention matrix. Instead, it computes attention in tiles that fit entirely in SRAM, performing the full attention computation without writing the intermediate S or P matrices to HBM.

The algorithm relies on two key insights:

Insight 1: Tiling. Divide Q, K, and V into blocks of size B_r \times d (for Q) and B_c \times d (for K, V), chosen so that all intermediate results fit in SRAM. Each tile computes a partial attention output.

Insight 2: Online softmax. The softmax normalization requires the full row of attention scores to compute the denominator \sum_j e^{s_{ij}}. The online softmax trick (Milakov and Gimelshein, 2018) maintains a running maximum and running sum, allowing the softmax to be computed incrementally as new blocks of K are processed:

For each new block j, update:

m_{\text{new}} = \max(m_{\text{old}}, \max(s_j))

\ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot \ell_{\text{old}} + \sum_k e^{s_{jk} - m_{\text{new}}}

O_{\text{new}} = \frac{e^{m_{\text{old}} - m_{\text{new}}} \cdot \ell_{\text{old}} \cdot O_{\text{old}} + e^{s_j - m_{\text{new}}} V_j}{\ell_{\text{new}}}

This rescaling ensures numerical stability (subtracting the running maximum) and correctness (properly re-weighting partial results as the global normalization constant changes).
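The update rules can be checked against an ordinary softmax in a few lines. A didactic sketch over a single attention row, processed block by block (it defers the division by \ell to the end, which is algebraically equivalent to dividing every step):

```python
import torch

def online_softmax_weighted_sum(scores: torch.Tensor, values: torch.Tensor, block: int) -> torch.Tensor:
    """Compute softmax(scores) @ values one block at a time, never holding the full softmax."""
    m = torch.tensor(float("-inf"))       # running max
    ell = torch.tensor(0.0)               # running normalizer
    out = torch.zeros(values.shape[-1])   # running unnormalized output
    for start in range(0, scores.shape[0], block):
        s = scores[start : start + block]
        v = values[start : start + block]
        m_new = torch.maximum(m, s.max())
        p = torch.exp(s - m_new)           # stable block weights
        corr = torch.exp(m - m_new)        # rescale old partial results
        ell = corr * ell + p.sum()
        out = corr * out + p @ v
        m = m_new
    return out / ell

torch.manual_seed(0)
s, v = torch.randn(64), torch.randn(64, 16)
ref = torch.softmax(s, dim=0) @ v
assert torch.allclose(online_softmax_weighted_sum(s, v, block=8), ref, atol=1e-5)
```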

FlashAttention Pseudocode

The algorithm in simplified form:

Algorithm: FlashAttention Forward Pass
Input: Q, K, V in HBM, block sizes B_r, B_c

1. Divide Q into blocks Q_1, ..., Q_{T_r} of size B_r x d
2. Divide K, V into blocks K_1,...,K_{T_c}, V_1,...,V_{T_c} of size B_c x d
3. For each query block Q_i (outer loop):
     a. Load Q_i into SRAM
     b. Initialize: O_i = 0, l_i = 0, m_i = -inf
     c. For each KV block (K_j, V_j) (inner loop):
          i.   Load K_j, V_j into SRAM
          ii.  Compute S_ij = Q_i @ K_j^T / sqrt(d)  [in SRAM]
          iii. Compute block max: m_ij = rowmax(S_ij)
          iv.  Update running max: m_new = max(m_i, m_ij)
          v.   Rescale: P_ij = exp(S_ij - m_new)
          vi.  Update: l_new = exp(m_i - m_new) * l_i + rowsum(P_ij)
          vii. Update: O_i = (exp(m_i - m_new) * l_i * O_i + P_ij @ V_j) / l_new
          viii. m_i = m_new, l_i = l_new
     d. Write O_i to HBM
4. Return O

The critical point is that S_{ij} is a small block (e.g., 128 \times 128) that fits in SRAM. It is never written to HBM. The only HBM reads are Q, K, and V (once each), and the only HBM write is the output O.
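The pseudocode translates almost line-for-line into a deliberately slow, pure-PyTorch reference that can be checked against standard attention. This is a didactic sketch, not a kernel (and like FlashAttention-2, it defers the division by l until the end of the inner loop):

```python
import math
import torch

def flash_attention_reference(Q, K, V, Br: int = 16, Bc: int = 16):
    """Tiled attention: only Br x Bc score blocks are ever materialized."""
    N, d = Q.shape
    O = torch.zeros_like(Q)
    for i in range(0, N, Br):                              # outer loop over query blocks
        Qi = Q[i : i + Br]
        Oi = torch.zeros(Qi.shape[0], d)
        li = torch.zeros(Qi.shape[0], 1)
        mi = torch.full((Qi.shape[0], 1), float("-inf"))
        for j in range(0, N, Bc):                          # inner loop over KV blocks
            Sij = Qi @ K[j : j + Bc].T / math.sqrt(d)      # small block ("in SRAM")
            m_new = torch.maximum(mi, Sij.max(dim=-1, keepdim=True).values)
            Pij = torch.exp(Sij - m_new)
            corr = torch.exp(mi - m_new)                   # rescale earlier partial results
            li = corr * li + Pij.sum(dim=-1, keepdim=True)
            Oi = corr * Oi + Pij @ V[j : j + Bc]
            mi = m_new
        O[i : i + Br] = Oi / li                            # normalize once per query block
    return O

torch.manual_seed(0)
Q, K, V = torch.randn(64, 32), torch.randn(64, 32), torch.randn(64, 32)
ref = torch.softmax(Q @ K.T / math.sqrt(32), dim=-1) @ V
assert torch.allclose(flash_attention_reference(Q, K, V), ref, atol=1e-5)
```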

Performance and Memory Complexity

| Method | HBM Reads/Writes | Memory | Wall-Clock Speed |
| --- | --- | --- | --- |
| Standard Attention | O(N^2) reads/writes | O(N^2) for attention matrix | Baseline |
| FlashAttention | O(N^2 d / M) reads/writes* | O(N) -- no attention matrix stored | 2-4x faster |
| FlashAttention-2 | Same asymptotics | O(N) | ~2x faster than FA-1 |
| FlashAttention-3 | Same asymptotics | O(N) | Further gains on Hopper GPUs |

*M is the SRAM size. This IO complexity is optimal in the sense that no exact attention algorithm can achieve asymptotically fewer HBM accesses.

FlashAttention-2 and -3

FlashAttention-2 (Dao, 2023) improved on the original with:

  • Parallelizing over the sequence length dimension (outer loop over Q blocks) instead of batch/heads
  • Reducing non-matmul FLOPs (warp shuffles, shared memory reads)
  • Better work partitioning across warps within a thread block

FlashAttention-3 (Dao et al., 2024) targets NVIDIA Hopper (H100) GPUs with:

  • Asynchronous data movement using TMA (Tensor Memory Accelerator)
  • Warp specialization: different warps handle memory operations vs compute
  • FP8 quantization support for further speedups

Using FlashAttention in Practice

PyTorch 2.0+ integrates FlashAttention into scaled_dot_product_attention:

import torch
import torch.nn.functional as F


# PyTorch automatically selects the best attention backend:
# - FlashAttention (for supported dtypes and GPU architectures)
# - Memory-efficient attention (xformers-style)
# - Standard math attention (fallback)
def efficient_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    is_causal: bool = True
) -> torch.Tensor:
    """Compute attention using PyTorch's SDPA with automatic backend selection.

    Args:
        q: (batch, n_heads, seq_len, head_dim)
        k: (batch, n_kv_heads, seq_len, head_dim)
        v: (batch, n_kv_heads, seq_len, head_dim)
        is_causal: Whether to apply causal masking.

    Returns:
        Output tensor of shape (batch, n_heads, seq_len, head_dim).
    """
    # For GQA: expand KV heads to match query heads
    if k.shape[1] != q.shape[1]:
        n_groups = q.shape[1] // k.shape[1]
        k = k.repeat_interleave(n_groups, dim=1)
        v = v.repeat_interleave(n_groups, dim=1)

    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=is_causal
    )

You can verify which backend is being used:

# Check available backends
from torch.backends.cuda import (
    flash_sdp_enabled,
    mem_efficient_sdp_enabled
)

print(f"FlashAttention available: {flash_sdp_enabled()}")
print(f"Memory-efficient attention available: {mem_efficient_sdp_enabled()}")

4. Sliding Window Attention

Motivation

Even with FlashAttention's O(N) memory, the compute cost of attention remains O(N^2) in FLOPs. For very long sequences (100K+ tokens), this quadratic cost becomes significant. Sliding window attention limits each token's attention to a fixed window of the W previous tokens, reducing the compute to O(N \times W).

How It Works

Instead of allowing each token to attend to all previous tokens, sliding window attention restricts the attention span:

\text{Attention}_{\text{SWA}}(q_i, K, V) = \text{softmax}\left(\frac{q_i K_{[i-W:i]}^\top}{\sqrt{d_k}}\right) V_{[i-W:i]}

where K_{[i-W:i]} denotes the keys within the window [\max(0, i-W), i].

For a window size W and sequence length N:

| Property | Full Attention | Sliding Window |
| --- | --- | --- |
| Attention span per token | All previous tokens | W previous tokens |
| FLOPs per layer | O(N^2 \cdot d) | O(N \cdot W \cdot d) |
| KV-cache size | Grows with N | Fixed at W |
| Effective receptive field | N per layer | W \times L across L layers |

The key insight behind the effective receptive field: with a window size W and L layers, information can propagate up to W \times L tokens through the stacked layers and residual connections. Mistral 7B uses W = 4096 with 32 layers, giving an effective receptive field of 131,072 tokens even though each layer only attends to 4,096 tokens.

Implementation Considerations

Sliding window attention is efficiently supported by FlashAttention-2 and later versions. The causal mask is modified to mask out positions outside the window:

M_{ij} = \begin{cases} 0 & \text{if } i - W < j \leq i \\ -\infty & \text{otherwise} \end{cases}

During inference with a KV-cache, the cache only needs to store the last W tokens, making memory usage constant regardless of the total sequence length. This is particularly valuable for long-context inference.
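The mask above can be built directly and passed to scaled_dot_product_attention (a sketch for clarity; FlashAttention kernels implement the window natively rather than via an explicit mask):

```python
import torch
import torch.nn.functional as F

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Additive mask: 0 where i - W < j <= i, -inf elsewhere."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    allowed = (j <= i) & (j > i - window)
    return torch.where(allowed, 0.0, float("-inf"))

mask = sliding_window_mask(seq_len=8, window=3)
# Token 5 may attend to tokens 3, 4, 5 only:
assert (mask[5] == 0).nonzero().flatten().tolist() == [3, 4, 5]

q = k = v = torch.randn(1, 1, 8, 16)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```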

Models Using Sliding Window Attention

| Model | Window Size | Layers | Effective Receptive Field |
| --- | --- | --- | --- |
| Mistral 7B | 4,096 | 32 | 131,072 |
| Mixtral 8x7B | 4,096 | 32 | 131,072 |
| Gemma 2 (alternating) | Mix of local and global | 42 | Full context |

Gemma 2 uses an interesting hybrid: alternating layers of sliding window attention (for efficiency) and full global attention (for long-range dependencies). This gives the best of both worlds.

Summary

| Innovation | Problem Solved | Mechanism | Adopted By |
| --- | --- | --- | --- |
| RoPE | Position encoding with extrapolation | Rotation in complex plane | LLaMA, Mistral, Qwen, Gemma |
| GQA | KV-cache memory during inference | Shared KV heads across query groups | LLaMA 2/3, Mistral |
| FlashAttention | Memory bandwidth bottleneck | Tiling + online softmax in SRAM | Nearly all modern LLMs |
| Sliding Window | Quadratic complexity for long sequences | Fixed-size attention window per layer | Mistral, Gemma 2 |

These modifications are complementary and are typically used together. A modern LLM like LLaMA 3 or Mistral combines RoPE for position encoding, GQA for KV-cache efficiency, and FlashAttention for IO-aware computation -- all within a Pre-LN decoder-only architecture with RMSNorm.


In the next post, we will turn to Part 4: FFN Modifications -- the evolution of the feed-forward network from simple ReLU MLPs to gated architectures like SwiGLU, and how Mixture of Experts (MoE) enables scaling to trillion-parameter models while keeping inference costs manageable.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
  2. Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., Liu, Y. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864.
  3. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150.
  4. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebron, F., Sanghai, S. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.
  5. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., Re, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  6. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691.
  7. Press, O., Smith, N. A., Lewis, M. (2022). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR 2022.
  8. Chen, S., Wong, S., Chen, L., Tian, Y. (2023). "Extending Context Window of Large Language Models via Position Interpolation." arXiv:2306.15595.
  9. Jiang, A. Q., Sablayrolles, A., Mensch, A., et al. (2023). "Mistral 7B." arXiv:2310.06825.
  10. Milakov, M. and Gimelshein, N. (2018). "Online Normalizer Calculation for Softmax." arXiv:1805.02867.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
