We have spent seven posts dissecting the transformer: its original design, architectural refinements, attention mechanisms, feed-forward networks, training recipes, inference optimization, and the small-but-important engineering details. All of these improvements share one fundamental limitation: self-attention still computes interactions between every pair of tokens, giving it O(n^2) time and memory complexity in sequence length.

For a 4K context window, this is manageable. For 128K or 1M tokens, it becomes the primary bottleneck. This final post surveys the architectures that have emerged to break past this quadratic wall: State Space Models, Mamba, Linear Attention, RWKV, and the hybrid designs that combine the best of multiple paradigms.

The Quadratic Wall

Self-attention computes:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

The QK^\top product creates an n \times n attention matrix. Even with FlashAttention reducing memory from O(n^2) to O(n), the compute remains O(n^2 d). Concretely:

| Sequence Length | Attention FLOPs (d = 128) | Ratio to 4K |
| --- | --- | --- |
| 4,096 | ~4.3 billion | 1x |
| 32,768 | ~275 billion | 64x |
| 131,072 | ~4.4 trillion | 1,024x |
| 1,048,576 | ~281 trillion | 65,536x |

Doubling the sequence length quadruples the cost. This scaling fundamentally limits how far standard transformers can push context length, even with clever systems engineering.
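The numbers above can be reproduced with a few lines, counting the QK^\top product at 2 FLOPs per multiply-add (exact constants vary by implementation; this counts only the dominant score-matrix term):

```python
def attention_flops(n: int, d: int = 128) -> int:
    """FLOPs for the QK^T product: n^2 dot products of length d."""
    return 2 * n * n * d

base = attention_flops(4096)
for n in (4096, 32768, 131072, 1048576):
    print(f"{n:>9,} tokens: {attention_flops(n):.3e} FLOPs, {attention_flops(n) // base}x")
```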

The alternative architectures we examine here all achieve O(n) or O(n \log n) complexity by replacing the dense attention matrix with structured recurrences, linear maps, or state-based computations.

[Figure: SSM vs. Transformer complexity comparison]

State Space Models (SSMs)

From Control Theory to Sequence Modeling

State Space Models have a long history in control theory and signal processing. The core idea is simple: maintain a hidden state h(t) that evolves over time according to a linear differential equation:

\frac{dh(t)}{dt} = Ah(t) + Bx(t)
y(t) = Ch(t) + Dx(t)

where:

  • x(t) \in \mathbb{R} is the input signal at time t
  • h(t) \in \mathbb{R}^N is a hidden state vector of dimension N (the "state size")
  • y(t) \in \mathbb{R} is the output
  • A \in \mathbb{R}^{N \times N} is the state transition matrix governing how the state evolves
  • B \in \mathbb{R}^{N \times 1} maps the input into the state
  • C \in \mathbb{R}^{1 \times N} reads out from the state to produce the output
  • D \in \mathbb{R} is a skip connection (often set to zero)

The power of this formulation is that A determines the memory of the system. Different eigenstructures of A correspond to different temporal behaviors: oscillatory patterns, exponential decays, or long-range dependencies.

Discretization: From Continuous to Discrete

Neural networks operate on discrete sequences, not continuous signals. We need to convert the continuous ODE into a discrete recurrence using a step size \Delta. The most common approach is zero-order hold (ZOH) discretization:

\bar{A} = \exp(\Delta A)
\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B

This gives us the discrete recurrence:

h_k = \bar{A} h_{k-1} + \bar{B} x_k
y_k = C h_k + D x_k

The step size \Delta acts as a resolution parameter. A small \Delta makes the model attend to fine-grained details; a large \Delta makes it focus on coarse, long-range patterns. This provides an intuitive knob absent from attention-based models.

For the special case where A is diagonal (which most modern SSMs assume), the matrix exponential simplifies to element-wise exponentiation:

\bar{A}_{ii} = \exp(\Delta \cdot A_{ii})

This makes the discretized state transition a simple element-wise operation, avoiding the O(N^3) matrix exponential.
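A quick numerical check of this shortcut (a sketch with arbitrary toy values; torch.matrix_exp gives the general-case reference):

```python
import torch

# Toy diagonal A with negative (stable) modes -- illustrative values
A_diag = torch.tensor([-0.5, -1.0, -2.0])
dt = 0.1

# Element-wise discretization, valid because A is diagonal
A_bar_elementwise = torch.exp(dt * A_diag)

# General-case reference: full matrix exponential of diag(A)
A_bar_full = torch.matrix_exp(dt * torch.diag(A_diag))

assert torch.allclose(A_bar_elementwise, A_bar_full.diagonal())
```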

The Dual View: Recurrence and Convolution

A remarkable property of linear SSMs is that they have two equivalent computational modes:

Recurrent mode (for inference): Process tokens one at a time, maintaining a hidden state. Cost per token: O(N), where N is the state dimension.

h_k = \bar{A} h_{k-1} + \bar{B} x_k, \quad y_k = C h_k

Convolutional mode (for training): Unroll the recurrence into a global convolution. By expanding the recurrence:

y_0 = C\bar{B} x_0
y_1 = C\bar{A}\bar{B} x_0 + C\bar{B} x_1
y_k = \sum_{j=0}^{k} C\bar{A}^{k-j}\bar{B} x_j

This is a causal convolution y = x * \bar{K} with kernel:

\bar{K} = (C\bar{B},\ C\bar{A}\bar{B},\ C\bar{A}^2\bar{B},\ \ldots,\ C\bar{A}^{L-1}\bar{B})

Using FFT, this convolution costs O(L \log L) for a sequence of length L, far better than the O(L^2) of attention. During training we use the convolutional mode for parallelism; during inference we switch to the recurrent mode, whose per-step cost is constant in the sequence length.
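The equivalence of the two modes is easy to verify on a toy diagonal SSM (illustrative sizes, not tied to any particular model):

```python
import torch

torch.manual_seed(0)
N, L = 4, 16                  # state size, sequence length
A_bar = torch.rand(N) * 0.9   # diagonal, stable (|A_bar| < 1)
B_bar = torch.randn(N)
C = torch.randn(N)
x = torch.randn(L)

# Recurrent mode: h_k = A_bar * h_{k-1} + B_bar * x_k, y_k = C . h_k
h = torch.zeros(N)
y_rec = []
for k in range(L):
    h = A_bar * h + B_bar * x[k]
    y_rec.append((C * h).sum())
y_rec = torch.stack(y_rec)

# Convolutional mode: K_bar[m] = C @ diag(A_bar)^m @ B_bar
powers = A_bar.unsqueeze(-1) ** torch.arange(L)                   # (N, L)
K_bar = (C.unsqueeze(-1) * powers * B_bar.unsqueeze(-1)).sum(0)   # (L,)
y_conv = torch.stack(
    [(K_bar[: k + 1].flip(0) * x[: k + 1]).sum() for k in range(L)]
)

assert torch.allclose(y_rec, y_conv, atol=1e-5)
```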

S4: The Breakthrough

The S4 paper ("Efficiently Modeling Long Sequences with Structured State Spaces", Gu et al., 2022) made SSMs practical by solving three critical problems:

Problem 1: Long-range dependencies. A naive SSM with a random A matrix suffers from vanishing gradients over long sequences (the eigenvalues of \bar{A}^k decay exponentially). S4 solved this with the HiPPO (High-order Polynomial Projection Operator) initialization:

A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}

This specific matrix structure has the property that the hidden state h(t) optimally compresses the entire input history x(\tau) for \tau \leq t by projecting it onto a basis of Legendre polynomials. This is not an arbitrary choice; it is provably optimal for a certain class of approximation problems.
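The piecewise formula translates directly into code; a small sketch constructing the matrix and checking its structure:

```python
import torch

def hippo_legs(N: int) -> torch.Tensor:
    """Construct the HiPPO-LegS A matrix from the piecewise formula."""
    n = torch.arange(N, dtype=torch.float32).unsqueeze(1)  # row index
    k = torch.arange(N, dtype=torch.float32).unsqueeze(0)  # column index
    A = torch.where(
        n > k,
        ((2 * n + 1) ** 0.5) * ((2 * k + 1) ** 0.5),
        torch.where(n == k, n + 1, torch.zeros(1)),
    )
    return -A

A = hippo_legs(4)
# Lower triangular, with diagonal entries -(n + 1)
assert torch.equal(A.diagonal(), -torch.tensor([1.0, 2.0, 3.0, 4.0]))
assert torch.equal(A.triu(1), torch.zeros(4, 4))
```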

Problem 2: Efficient kernel computation. Computing \bar{K} naively requires L matrix-vector products with \bar{A}. S4 showed that when A has special structure (diagonal plus low-rank, or purely diagonal in later work like S4D and S5), the kernel can be computed in near-linear time using a generating-function approach.

Problem 3: Stability. Training a system with a recurrence h_k = \bar{A}h_{k-1} + \ldots is inherently unstable if eigenvalues of \bar{A} exceed 1 in magnitude. S4 parameterizes A to guarantee stability by constraining eigenvalues to have negative real parts (in continuous time) or magnitude less than 1 (in discrete time).

SSM Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class S4DLayer(nn.Module):
    """
    Simplified S4D (diagonal) layer.
    Uses diagonal state matrix for efficient computation.
    """

    def __init__(self, d_model: int, state_dim: int = 64, dt_min: float = 0.001,
                 dt_max: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim

        # HiPPO-inspired diagonal initialization (S4D-Lin variant)
        # Real part: negative half-integers for stability
        A_real = -0.5 * torch.ones(d_model, state_dim)
        # Imaginary part: frequencies for oscillatory modes
        A_imag = math.pi * torch.arange(state_dim).float().repeat(d_model, 1)
        self.A_log_real = nn.Parameter(torch.log(-A_real))  # Log for positivity
        self.A_imag = nn.Parameter(A_imag)

        # Input and output projections (complex-valued)
        self.B_re = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
        self.B_im = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
        self.C_re = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
        self.C_im = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)

        # Learnable discretization step size
        log_dt = torch.rand(d_model) * (
            math.log(dt_max) - math.log(dt_min)
        ) + math.log(dt_min)
        self.log_dt = nn.Parameter(log_dt)

        self.D = nn.Parameter(torch.ones(d_model))  # Skip connection

    def get_kernel(self, length: int) -> torch.Tensor:
        """Compute the convolution kernel of given length."""
        dt = torch.exp(self.log_dt)  # (d_model,)
        A_real = -torch.exp(self.A_log_real)  # Negative for stability
        A = torch.complex(A_real, self.A_imag)  # (d_model, state_dim)
        B = torch.complex(self.B_re, self.B_im)
        C = torch.complex(self.C_re, self.C_im)

        # Discretize: A_bar = exp(dt * A)
        dt_A = dt.unsqueeze(-1) * A  # (d_model, state_dim)
        A_bar = torch.exp(dt_A)
        B_bar = (A_bar - 1.0) / (dt_A + 1e-8) * (dt.unsqueeze(-1) * B)

        # Build kernel: K[k] = C @ A_bar^k @ B_bar
        # (direct power evaluation; real S4D uses faster structured algorithms)
        powers = torch.arange(length, device=A.device).unsqueeze(0).unsqueeze(0)
        # A_bar^k for each position
        A_bar_k = A_bar.unsqueeze(-1) ** powers  # (d_model, state_dim, length)
        kernel = torch.einsum("dn,dnl->dl", C * B_bar, A_bar_k)

        return kernel.real  # (d_model, length)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, length, d_model)
        Returns:
            y: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # Compute kernel
        kernel = self.get_kernel(length)  # (d_model, length)

        # Convolutional mode: use FFT for O(L log L) computation
        x_t = x.transpose(1, 2)  # (batch, d_model, length)

        # FFT convolution: zero-pad to 2L so the circular convolution
        # equals the linear (causal) convolution on the first L outputs
        n_fft = 2 * length
        k_f = torch.fft.rfft(kernel, n=n_fft)
        x_f = torch.fft.rfft(x_t, n=n_fft)
        y = torch.fft.irfft(x_f * k_f.unsqueeze(0), n=n_fft)
        y = y[..., :length]  # Trim to original length

        # Skip connection
        y = y + self.D.unsqueeze(0).unsqueeze(-1) * x_t

        return y.transpose(1, 2)  # (batch, length, d_model)

Mamba: Selective State Spaces

The Limitation of Linear SSMs

All the models discussed so far (S4, S4D, S5) share a critical limitation: the parameters A, B, and C are fixed (independent of the input). This means the model treats every input token identically in terms of how it updates and reads the state.

Consider processing the sentence "The capital of France is Paris." A fixed SSM applies the same dynamics to "The" (uninformative) and "France" (critical). It cannot selectively store or retrieve information based on content. This is fundamentally different from attention, where the query-key mechanism explicitly selects which tokens to attend to.

Mamba's Key Innovation: Input-Dependent Parameters

Mamba (Gu and Dao, 2024) solves this by making B, C, and \Delta functions of the input:

B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}_C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(x_t))

Now the model can:

  • Selectively remember: When processing an important token, it can set \Delta_t large, causing \bar{A}_t = \exp(\Delta_t A) to push eigenvalues toward zero, effectively resetting the state and writing new information via \bar{B}_t x_t.
  • Selectively ignore: For unimportant tokens, a small \Delta_t keeps \bar{A}_t close to the identity, preserving the existing state and barely incorporating the new input.
  • Selectively output: The input-dependent C_t controls what information is read from the state at each step.

This selectivity mechanism provides content-aware processing comparable to attention, but within a recurrent framework that maintains O(n) complexity.
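The effect of \Delta_t on the discretized transition is easy to see numerically (toy values for illustration):

```python
import torch

A = torch.tensor(-1.0)  # one negative eigenvalue, as in Mamba's diagonal A

# Large step: transition collapses toward 0, so the state is reset
A_bar_large = torch.exp(torch.tensor(8.0) * A)    # exp(-8), near zero

# Small step: transition stays near 1, so the state is preserved
A_bar_small = torch.exp(torch.tensor(0.001) * A)  # exp(-0.001), near one

assert A_bar_large < 0.01 and A_bar_small > 0.99
```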

The Hardware-Aware Parallel Scan

Making parameters input-dependent breaks the convolution view: since \bar{A}_t varies with t, we can no longer express the computation as a fixed convolution. However, we can still parallelize using the parallel scan algorithm.

The recurrence h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t is an instance of a first-order linear recurrence, which can be computed in O(n) work and O(\log n) depth using the parallel prefix sum (scan) algorithm. The key insight is that pairs (a, b) representing h = a \cdot h_{prev} + b can be composed associatively:

(a_2, b_2) \circ (a_1, b_1) = (a_2 \cdot a_1,\ a_2 \cdot b_1 + b_2)

Mamba implements this with a custom CUDA kernel that is carefully designed for GPU memory hierarchy:

  1. Load from HBM to SRAM: The input x, discretized parameters \bar{A}, \bar{B}, and C are loaded once.
  2. Scan in SRAM: The parallel scan runs entirely in on-chip SRAM, avoiding repeated HBM reads.
  3. Write output to HBM: Only the final output y is written back.

This is directly analogous to the IO-awareness principle behind FlashAttention, applied to a different computation.
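The associativity of the composition rule is what makes the scan parallelizable: any grouping of the elements yields the same prefix. A small check in plain PyTorch (a sketch, not the CUDA kernel):

```python
import torch

def combine(e1, e2):
    """Compose two recurrence steps (apply e1 first, then e2)."""
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

torch.manual_seed(0)
a, b = torch.rand(4), torch.randn(4)

# Sequential reference: h_t = a_t * h_{t-1} + b_t, starting from h = 0
h = torch.tensor(0.0)
for t in range(4):
    h = a[t] * h + b[t]

e = [(a[i], b[i]) for i in range(4)]
# Left-to-right fold and balanced-tree grouping give the same prefix --
# this grouping freedom is what lets a scan run in O(log n) depth
left = combine(combine(combine(e[0], e[1]), e[2]), e[3])
tree = combine(combine(e[0], e[1]), combine(e[2], e[3]))

assert torch.allclose(left[0], tree[0]) and torch.allclose(left[1], tree[1])
assert torch.allclose(tree[1], h)  # b-component of the full prefix is h_3
```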

Mamba Architecture

Each Mamba block replaces the attention + FFN pair of a transformer layer:

Input x
  |
  v
Linear (expand by factor E=2)
  |
  +--> Branch A: Conv1d -> SiLU -> Selective SSM --+
  |                                                 |
  +--> Branch B: SiLU (gate) ----------------------+
                                                    |
                                                    v
                                              Element-wise multiply
                                                    |
                                                    v
                                              Linear (project back)
                                                    |
                                                    v
                                                Output y

This gated architecture is reminiscent of the GLU family we discussed in Part 4. The gate branch (Branch B) modulates the SSM output (Branch A), providing an additional nonlinear mixing mechanism.

Complete Mamba Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange


class SelectiveSSM(nn.Module):
    """
    Selective State Space Model -- the core of Mamba.
    Implements input-dependent B, C, and delta with a parallel scan.
    """

    def __init__(self, d_inner: int, d_state: int = 16, dt_rank: int | None = None,
                 dt_min: float = 0.001, dt_max: float = 0.1):
        super().__init__()
        self.d_inner = d_inner
        self.d_state = d_state
        self.dt_rank = dt_rank or math.ceil(d_inner / 16)

        # Input-dependent projections for B, C, and dt
        self.x_proj = nn.Linear(d_inner, self.dt_rank + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)

        # Initialize dt bias so that softplus(bias) is in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # softplus inverse
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # State matrix A: initialized as negative real values (log scale)
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))

        # D: skip connection
        self.D = nn.Parameter(torch.ones(d_inner))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, length, d_inner) -- post-convolution, post-SiLU
        Returns:
            y: (batch, length, d_inner)
        """
        batch, length, d_inner = x.shape
        d_state = self.d_state

        # Compute input-dependent B, C, dt
        x_dbl = self.x_proj(x)  # (batch, length, dt_rank + 2*d_state)
        dt, B, C = x_dbl.split([self.dt_rank, d_state, d_state], dim=-1)

        dt = self.dt_proj(dt)  # (batch, length, d_inner)
        dt = F.softplus(dt)    # Ensure positive step sizes

        # Discretize A
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), negative for stability

        # Compute discretized parameters
        # dA = exp(dt * A) -- element-wise for diagonal A
        dA = torch.exp(
            torch.einsum("bld,dn->bldn", dt, A)
        )  # (batch, length, d_inner, d_state)

        # dB = dt * B (simplified ZOH for small dt)
        dB_x = torch.einsum("bld,bln->bldn", dt * x, B)
        # (batch, length, d_inner, d_state)

        # Run the selective scan
        y = self.parallel_scan(dA, dB_x, C)

        # Skip connection
        y = y + x * self.D.unsqueeze(0).unsqueeze(0)

        return y

    def parallel_scan(
        self,
        dA: torch.Tensor,
        dB_x: torch.Tensor,
        C: torch.Tensor,
    ) -> torch.Tensor:
        """
        Implements the parallel scan for the selective SSM recurrence.

        Recurrence: h_t = dA_t * h_{t-1} + dB_x_t
        Output:     y_t = C_t @ h_t

        This is a simplified (pure PyTorch) implementation. The real Mamba uses
        custom CUDA kernels for hardware-aware execution.

        Args:
            dA:   (batch, length, d_inner, d_state) -- discretized state transition
            dB_x: (batch, length, d_inner, d_state) -- discretized input contribution
            C:    (batch, length, d_state)           -- output projection
        Returns:
            y:    (batch, length, d_inner)
        """
        batch, length, d_inner, d_state = dA.shape

        # Sequential scan (reference implementation)
        # In practice, this would use a work-efficient parallel scan on GPU
        h = torch.zeros(batch, d_inner, d_state, device=dA.device, dtype=dA.dtype)
        ys = []

        for t in range(length):
            h = dA[:, t] * h + dB_x[:, t]  # (batch, d_inner, d_state)
            y_t = torch.einsum("bdn,bn->bd", h, C[:, t])  # (batch, d_inner)
            ys.append(y_t)

        return torch.stack(ys, dim=1)  # (batch, length, d_inner)


class MambaBlock(nn.Module):
    """
    Full Mamba block: replaces the Attention + FFN pair of a transformer layer.

    Architecture:
        x -> Linear(expand) -> split into two branches
        Branch A: Conv1d -> SiLU -> SelectiveSSM
        Branch B: SiLU (gate)
        Output: (Branch A * Branch B) -> Linear(project)
    """

    def __init__(self, d_model: int, d_state: int = 16, expand: int = 2,
                 d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_model * expand
        self.d_state = d_state
        self.d_conv = d_conv

        # Input projection: expand and split into x and gate
        self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=False)

        # Depthwise convolution for local context
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,  # Depthwise
            bias=True,
        )

        # Selective SSM
        self.ssm = SelectiveSSM(
            d_inner=self.d_inner,
            d_state=d_state,
        )

        # Output projection: compress back to d_model
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Normalization
        self.norm = nn.RMSNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, length, d_model)
        Returns:
            (batch, length, d_model)
        """
        residual = x
        x = self.norm(x)

        # Project and split
        xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
        x_branch, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)

        # Branch A: conv -> activation -> SSM
        x_branch = x_branch.transpose(1, 2)  # (batch, d_inner, length)
        x_branch = self.conv1d(x_branch)[:, :, :x.size(1)]  # Causal: trim future
        x_branch = x_branch.transpose(1, 2)  # (batch, length, d_inner)
        x_branch = F.silu(x_branch)
        x_branch = self.ssm(x_branch)

        # Branch B: gate
        z = F.silu(z)

        # Combine and project
        y = x_branch * z
        y = self.out_proj(y)

        return y + residual


class MambaModel(nn.Module):
    """A complete Mamba language model."""

    def __init__(self, vocab_size: int, d_model: int, n_layers: int,
                 d_state: int = 16, expand: int = 2, d_conv: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state, expand, d_conv)
            for _ in range(n_layers)
        ])
        self.norm_f = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying (common in Mamba models)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        return self.lm_head(x)

Mamba's Performance Profile

Mamba's advantages are most pronounced during inference. Since the SSM state has a fixed size (d_{inner} \times d_{state}, typically 2048 \times 16 = 32{,}768 values), the per-token inference cost is constant regardless of sequence length. Compare this to a transformer's KV cache, which grows linearly with context:

| Metric | Transformer (1.4B) | Mamba (1.4B) |
| --- | --- | --- |
| Training throughput | 1x (baseline) | ~1x (comparable) |
| Inference throughput | 1x (baseline) | ~5x faster |
| Inference memory | O(n \cdot d) KV cache | O(d \cdot N) fixed state |
| Perplexity (Pile) | ~14.2 | ~14.0 |
| Long-range retrieval | Excellent | Good (not perfect) |

The one area where Mamba has shown weakness compared to transformers of equal size is in-context learning and retrieval tasks that require precise copying from the context. This is because the fixed-size state must compress all past information, making exact recall difficult for very specific tokens.
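To make the memory gap concrete, here is a back-of-envelope comparison; the layer counts and dimensions below are illustrative assumptions (fp16 storage), not the published configurations:

```python
def kv_cache_bytes(n_tokens, n_layers=48, d_model=2048, bytes_per=2):
    # 2 tensors (K and V) per layer, each n_tokens x d_model
    return 2 * n_layers * n_tokens * d_model * bytes_per

def mamba_state_bytes(n_layers=48, d_inner=4096, d_state=16, bytes_per=2):
    # Fixed-size SSM state per layer, independent of sequence length
    return n_layers * d_inner * d_state * bytes_per

print(f"KV cache @ 128K tokens:    {kv_cache_bytes(131072) / 1e9:.1f} GB")
print(f"Mamba state (any length):  {mamba_state_bytes() / 1e6:.1f} MB")
```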

Linear Attention

Removing the Softmax

Standard attention requires materializing the n \times n attention matrix because the softmax operates row-wise, coupling all key-query interactions:

\text{Attn}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

Linear attention (Katharopoulos et al., 2020) replaces softmax with a decomposable kernel:

\text{LinearAttn}(Q, K, V) = \frac{\phi(Q)\left(\phi(K)^\top V\right)}{\phi(Q)\left(\phi(K)^\top \mathbf{1}\right)}

where \phi(\cdot) is a feature map applied element-wise to queries and keys.

The Associativity Trick

The computational savings come from exploiting matrix multiplication associativity. Consider the unnormalized case:

Standard (left-to-right):

\underbrace{(\phi(Q) \cdot \phi(K)^\top)}_{n \times n} \cdot V \quad \Rightarrow \quad O(n^2 d)

Linear (right-to-left):

\phi(Q) \cdot \underbrace{(\phi(K)^\top \cdot V)}_{d \times d} \quad \Rightarrow \quad O(n d^2)

By computing S = \phi(K)^\top V first (a d \times d matrix), then multiplying each query row by S, we avoid constructing the n \times n attention matrix entirely. When d \ll n (which is typically true for long sequences), this is a massive saving.

For causal (autoregressive) attention, we need to ensure token t only attends to tokens at positions \leq t. This is achieved with a cumulative sum formulation:

S_t = S_{t-1} + \phi(k_t) v_t^\top
y_t = \phi(q_t) S_t

This is a recurrence on the d \times d state matrix S_t, making it compatible with efficient inference.
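The recurrence can be checked against the materialized (masked, unnormalized-scores) form on a toy example; here a positive ReLU-based stand-in for \phi keeps the denominators well-behaved:

```python
import torch

torch.manual_seed(0)
L, d = 6, 4
q = torch.relu(torch.randn(L, d)) + 0.1   # stand-in for phi(q), positive
k = torch.relu(torch.randn(L, d)) + 0.1   # stand-in for phi(k), positive
v = torch.randn(L, d)

# Quadratic form: causal mask on the raw (softmax-free) scores
scores = (q @ k.T).tril()                 # zero out future positions
y_quad = (scores @ v) / scores.sum(-1, keepdim=True)

# Recurrent form: S_t = S_{t-1} + k_t v_t^T, z_t = z_{t-1} + k_t
S = torch.zeros(d, d)
z = torch.zeros(d)
y_rec = []
for t in range(L):
    S = S + torch.outer(k[t], v[t])
    z = z + k[t]
    y_rec.append((q[t] @ S) / (q[t] @ z))
y_rec = torch.stack(y_rec)

assert torch.allclose(y_quad, y_rec, atol=1e-5)
```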

Feature Maps

The choice of ϕ\phi determines the quality-efficiency tradeoff:

| Feature Map | Definition | Properties |
| --- | --- | --- |
| Identity | \phi(x) = x | Simplest; can produce negative attention weights |
| ELU+1 | \phi(x) = \text{ELU}(x) + 1 | Positive; smooth; simple to implement |
| ReLU | \phi(x) = \max(0, x) | Positive; sparse; may lose information |
| Random Fourier | \phi(x) = \frac{1}{\sqrt{m}}[\cos(Wx), \sin(Wx)] | Approximates the softmax kernel; unbiased |
| Performer (FAVOR+) | Positive random features | Unbiased softmax approximation with positivity |

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """
    Linear attention with causal masking.
    Uses the ELU+1 feature map for simplicity and positivity.
    """

    def __init__(self, d_model: int, n_heads: int, feature_map: str = "elu"):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model

        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

        self.feature_map = feature_map

    def phi(self, x: torch.Tensor) -> torch.Tensor:
        """Apply feature map to ensure positive values."""
        if self.feature_map == "elu":
            return F.elu(x) + 1.0
        elif self.feature_map == "relu":
            return F.relu(x)
        elif self.feature_map == "identity":
            return x
        else:
            raise ValueError(f"Unknown feature map: {self.feature_map}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Causal linear attention using the cumulative sum formulation.

        Args:
            x: (batch, length, d_model)
        Returns:
            (batch, length, d_model)
        """
        batch, length, _ = x.shape

        q = self.wq(x).view(batch, length, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch, length, self.n_heads, self.head_dim)
        v = self.wv(x).view(batch, length, self.n_heads, self.head_dim)

        # Apply feature map
        q = self.phi(q)  # (batch, length, heads, head_dim)
        k = self.phi(k)

        # Causal linear attention via cumulative sum
        # S_t = sum_{i<=t} phi(k_i) @ v_i^T  (accumulated KV state)
        # z_t = sum_{i<=t} phi(k_i)           (accumulated normalizer)
        # y_t = (q_t @ S_t) / (q_t @ z_t)

        # Compute outer products k_i @ v_i^T for all positions
        kv = torch.einsum("blhd,blhe->blhde", k, v)  # (B, L, H, D, D)
        # Cumulative sum along sequence dimension
        S = torch.cumsum(kv, dim=1)  # (B, L, H, D, D)

        # Normalizer: cumulative sum of keys
        z = torch.cumsum(k, dim=1)  # (B, L, H, D)

        # Compute output
        y_num = torch.einsum("blhd,blhde->blhe", q, S)  # (B, L, H, D)
        y_den = torch.einsum("blhd,blhd->blh", q, z)     # (B, L, H)
        y_den = y_den.unsqueeze(-1).clamp(min=1e-6)       # Prevent division by zero

        y = y_num / y_den  # (B, L, H, D)

        y = y.reshape(batch, length, self.d_model)
        return self.wo(y)

Limitations of Linear Attention

Despite the appealing complexity, linear attention has well-documented limitations:

  1. No sharp attention patterns. Softmax attention can assign nearly all weight to a single token (approaching a one-hot distribution). Linear attention with positive feature maps produces smoother distributions, making it harder for the model to perform precise retrieval.

  2. The dilution problem. The cumulative state S_t accumulates all past KV pairs without discounting. Over very long sequences, the state becomes an average over many tokens, diluting individual contributions. Softmax attention avoids this because it re-normalizes at every step.

  3. Quality gap at moderate scale. Empirically, linear attention models tend to underperform softmax attention models of the same size by 1-3 perplexity points, though the gap narrows with scale.

These limitations motivated the development of more sophisticated approaches like RetNet, GLA (Gated Linear Attention), and the hybrid architectures discussed below.

RWKV: Transformer-Level Performance, RNN-Level Efficiency

Design Philosophy

RWKV (Peng et al., 2023) takes a different approach from SSMs. Rather than starting from control theory, it directly reformulates the attention mechanism to be computable as a linear recurrence. The name RWKV comes from the four core elements: Receptance (R), Weight (W), Key (K), and Value (V).

The WKV Mechanism

The core computation in RWKV is the "WKV" (Weighted Key-Value) operator:

\text{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} v_i + e^{u + k_t} v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}

where:

  • w \in \mathbb{R}^d is a learned per-channel decay rate (always positive, so past tokens are exponentially downweighted)
  • u \in \mathbb{R}^d is a learned bonus for the current token (allowing the model to upweight the current position)
  • k_t, v_t are key and value projections of the input at position t

This can be computed recurrently by maintaining two running sums:

a_t = e^{-w} \cdot a_{t-1} + e^{k_t} \cdot v_t
b_t = e^{-w} \cdot b_{t-1} + e^{k_t}
\text{wkv}_t = \frac{e^u \cdot e^{k_t} \cdot v_t + a_{t-1}}{e^u \cdot e^{k_t} + b_{t-1}}

The exponential decay e^{-w} is the key to RWKV's long-range capability: by learning different decay rates per channel, the model can maintain both short-term and long-term memory across different dimensions of the hidden state.
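A short script confirms that the running-sum recurrence reproduces the direct wkv definition (random toy tensors; everything is per-channel, as in RWKV):

```python
import torch

torch.manual_seed(0)
T, d = 8, 4
w = torch.rand(d) * 0.5 + 0.1   # positive per-channel decay
u = torch.randn(d)
k = torch.randn(T, d)
v = torch.randn(T, d)

# Direct evaluation of the wkv definition
wkv_direct = []
for t in range(T):
    num = torch.exp(u + k[t]) * v[t]
    den = torch.exp(u + k[t])
    for i in range(t):
        decay = torch.exp(-(t - 1 - i) * w + k[i])
        num = num + decay * v[i]
        den = den + decay
    wkv_direct.append(num / den)
wkv_direct = torch.stack(wkv_direct)

# Recurrent evaluation with running sums a, b
a = torch.zeros(d)
b = torch.zeros(d)
wkv_rec = []
for t in range(T):
    e_k = torch.exp(k[t])
    wkv_rec.append((torch.exp(u) * e_k * v[t] + a) / (torch.exp(u) * e_k + b))
    a = torch.exp(-w) * a + e_k * v[t]
    b = torch.exp(-w) * b + e_k
wkv_rec = torch.stack(wkv_rec)

assert torch.allclose(wkv_direct, wkv_rec, atol=1e-5)
```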

RWKV Block Structure

An RWKV block consists of two sub-blocks:

Input x
  |
  v
LayerNorm -> Time Mixing (WKV) -> + (residual) -> LayerNorm -> Channel Mixing -> + (residual)

Time mixing handles sequence-level interactions (analogous to attention):

r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})
k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})
v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1})
o_t = \sigma(r_t) \odot \text{wkv}_t

The \mu parameters control a token-shift mechanism: each projection sees a learned interpolation between the current token and the previous one, providing a simple form of local context without convolution.

Channel mixing handles feature-level interactions (analogous to FFN):

r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})
k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})
o_t = \sigma(r_t) \odot (W_v \cdot \max(k_t, 0)^2)

The squared ReLU activation provides the nonlinearity, while the receptance gate \sigma(r_t) controls information flow.
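A minimal sketch of the channel-mixing sub-block, assuming a simple zero-padded token shift (an illustrative module, not the official RWKV implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RWKVChannelMix(nn.Module):
    """Channel mixing: token shift, squared-ReLU MLP, receptance gate."""

    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.mu_r = nn.Parameter(torch.rand(d_model))
        self.mu_k = nn.Parameter(torch.rand(d_model))
        self.wr = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, hidden, bias=False)
        self.wv = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token shift: previous token's features, zeros at position 0
        x_prev = F.pad(x, (0, 0, 1, -1))
        r = self.wr(self.mu_r * x + (1 - self.mu_r) * x_prev)
        k = self.wk(self.mu_k * x + (1 - self.mu_k) * x_prev)
        return torch.sigmoid(r) * self.wv(torch.relu(k) ** 2)

x = torch.randn(2, 10, 64)
out = RWKVChannelMix(64, 256)(x)
assert out.shape == (2, 10, 64)
```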

RWKV vs. Attention: A Structural Comparison

| Property | Softmax Attention | RWKV |
| --- | --- | --- |
| Sequence interaction | All-pairs (n^2) | Recurrent (n) |
| Memory at inference | KV cache grows with n | Fixed-size state |
| Training | Parallel (matmul) | Parallel (custom kernel) |
| Positional awareness | Explicit (RoPE, etc.) | Implicit (decay + token shift) |
| Retrieval precision | Excellent (sharp attention) | Good (exponential decay) |
| Very long range (>100K) | Possible but expensive | Natural and efficient |

RWKV has been scaled to 14B parameters (RWKV-5 "Eagle" and RWKV-6 "Finch"), achieving performance competitive with similarly-sized transformers on standard benchmarks while offering constant-memory inference.

Hybrid Architectures: The Best of Both Worlds

Why Hybrids?

Each architecture excels in different regimes:

  • Attention is best for tasks requiring precise retrieval, in-context learning, and complex reasoning over specific context elements.
  • SSM/Mamba excels at long-range dependencies, efficient inference, and streaming applications.
  • Linear methods offer the best theoretical complexity but sacrifice some expressive power.

Hybrid architectures combine these strengths by using different mechanisms in different layers.

Jamba (AI21 Labs)

Jamba (Lieber et al., 2024) is the first production-scale hybrid Mamba-Transformer model. Its architecture interleaves Mamba layers with attention layers and adds Mixture of Experts:

| Layer Index | Type | MoE |
|---|---|---|
| 0 | Attention | Yes (top-2 of 16 experts) |
| 1-7 | Mamba | Yes |
| 8 | Attention | Yes |
| 9-15 | Mamba | Yes |
| 16 | Attention | Yes |
| ... | ... | ... |

The ratio is approximately 1 attention layer for every 7 Mamba layers. The attention layers handle precise retrieval tasks while the Mamba layers efficiently process the bulk of the sequence context. Jamba achieves:

  • 256K context length
  • 3x throughput improvement over a pure transformer of similar quality
  • Comparable quality to Mixtral 8x7B on standard benchmarks
  • Significantly reduced KV cache size (only the few attention layers need a cache)

Mamba-2: Structured State Space Duality

Mamba-2 (Dao and Gu, 2024) establishes a deep theoretical connection between SSMs and attention, which they call the State Space Duality (SSD) framework. The key insight is that the selective SSM computation can be expressed as a specific form of structured (semiseparable) matrix multiplication:

y = M \cdot (x \odot B_x)

where M is a semiseparable matrix determined by the discretized state transition. This matrix has the same structure as a causal attention matrix with specific masking patterns.

This duality enables:

  1. Larger state dimensions (64 or 128 instead of 16) because the computation maps to efficient matrix multiplies on modern GPUs.
  2. Direct architectural interleaving with attention layers, since both operate through matrix multiplications of compatible shapes.
  3. 2-8x speedup over Mamba-1 through better hardware utilization.
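The duality can be checked numerically in the simplest (scalar-decay, single-channel) case: the step-by-step recurrence and a masked matrix multiply compute the same outputs. A toy sketch of that equivalence, not the Mamba-2 kernel:

```python
import torch

def ssm_recurrent(x, a, B, C):
    """h_t = a_t * h_{t-1} + B_t * x_t;  y_t = C_t . h_t
    Shapes: x, a: (T,);  B, C: (T, N)."""
    h = torch.zeros(B.shape[1], dtype=x.dtype)
    ys = []
    for t in range(x.shape[0]):
        h = a[t] * h + B[t] * x[t]
        ys.append(C[t] @ h)
    return torch.stack(ys)

def ssm_semiseparable(x, a, B, C):
    """Same map written as y = M x, where
    M[t, s] = (C_t . B_s) * prod_{k=s+1..t} a_k for s <= t."""
    cum = torch.cumsum(torch.log(a), dim=0)
    decay = torch.exp(cum[:, None] - cum[None, :])  # prod of a over (s, t]
    M = torch.tril((C @ B.T) * decay)               # causal mask
    return M @ x
```

For random inputs with decay values in (0, 1) the two functions agree to numerical precision; the lower-triangular M is exactly the "attention-like" structured matrix that lets Mamba-2 run on the same matmul hardware paths as attention.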

Practical Hybrid Design Patterns

Based on published architectures, several hybrid design patterns have emerged:

Pattern 1: Sparse Interleaving (Jamba) Place attention layers at regular intervals (every 6-8 layers). The attention layers act as "global synchronization points" that can perform precise retrieval, while Mamba layers handle local and medium-range processing efficiently.

Pattern 2: First-and-Last (StripedHyena) Use attention in the first and last few layers, with SSM layers in the middle. The intuition is that early layers need to establish token-level representations (which benefits from attention), and final layers need precise output selection, while middle layers primarily propagate and transform features.

Pattern 3: Parallel Hybrid Run attention and SSM in parallel within the same layer and add their outputs. This preserves both pathways without increasing depth. Some Mamba-2 variants explore this pattern.

import torch
import torch.nn as nn


class HybridBlock(nn.Module):
    """
    A flexible hybrid block that can be configured as
    Mamba-only, Attention-only, or parallel hybrid.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int,
                 d_state: int = 16, mode: str = "mamba"):
        super().__init__()
        self.mode = mode
        self.norm = nn.RMSNorm(d_model)

        if mode in ("attention", "parallel"):
            # BiasFreeLLaMAAttention, BiasFreeSwiGLU, and MambaBlock are the
            # building blocks from earlier in this series; here n_kv_heads
            # is set equal to n_heads (plain multi-head attention).
            self.attn = BiasFreeLLaMAAttention(d_model, n_heads, n_heads)

        if mode in ("mamba", "parallel"):
            self.mamba = MambaBlock(d_model, d_state=d_state)

        self.ffn = BiasFreeSwiGLU(d_model, d_ff)
        self.ffn_norm = nn.RMSNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mode == "attention":
            x = x + self.attn(self.norm(x))
        elif self.mode == "mamba":
            x = self.mamba(x)  # MambaBlock includes its own norm + residual
            x = x + self.ffn(self.ffn_norm(x))
            return x
        elif self.mode == "parallel":
            normed = self.norm(x)
            # Assumes MambaBlock exposes its inner SSM path (without its own
            # norm and residual) as `.ssm`; both branches share one residual.
            x = x + self.attn(normed) + self.mamba.ssm(normed)

        x = x + self.ffn(self.ffn_norm(x))
        return x


class JambaStyleModel(nn.Module):
    """
    Simplified Jamba-style hybrid: interleave attention and Mamba layers.
    Every `attn_every` layers uses attention; others use Mamba.
    """

    def __init__(self, vocab_size: int, d_model: int, n_layers: int,
                 n_heads: int, d_ff: int, d_state: int = 16,
                 attn_every: int = 7):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            if i % attn_every == 0:
                self.layers.append(
                    HybridBlock(d_model, n_heads, d_ff, d_state, mode="attention")
                )
            else:
                self.layers.append(
                    HybridBlock(d_model, n_heads, d_ff, d_state, mode="mamba")
                )

        self.norm_f = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        return self.lm_head(x)

Practical Decision Matrix

When choosing an architecture for a new project, the decision depends on your specific constraints:

| Criterion | Best Choice | Runner-Up | Avoid |
|---|---|---|---|
| Quality at any cost | Transformer (dense attention) | Hybrid (Jamba-style) | Pure linear attention |
| Very long context (>128K) | Mamba / Hybrid | RWKV | Pure transformer |
| Low-latency inference | Mamba | RWKV | Transformer (large KV cache) |
| Streaming / real-time | Mamba, RWKV | Hybrid | Transformer |
| In-context learning | Transformer | Hybrid | Pure SSM |
| Retrieval-heavy tasks | Transformer | Hybrid | Mamba (weaker at exact recall) |
| Memory-constrained deployment | RWKV, Mamba | Quantized transformer | Full-precision transformer |
| Largest proven scale (>100B) | Transformer | Hybrid (early results) | Pure Mamba (not yet proven) |
| Research flexibility | Transformer (most tooling) | Mamba (growing ecosystem) | RWKV (smaller community) |

The practical recommendation in 2025 is:

  1. Default to transformers if you need the best quality and your context length is under 32K. The ecosystem (frameworks, quantization tools, serving infrastructure) is most mature.
  2. Use hybrids if you need long context (>64K) with high quality. The combination of a few attention layers with many SSM layers gives the best quality-efficiency tradeoff.
  3. Use pure Mamba if your primary constraint is inference throughput or memory, and you can accept a small quality tradeoff.
  4. Use RWKV if you need an open-source, well-tested alternative with RNN-level inference efficiency and an active community.

Conclusion: The Journey Through Transformers

This post concludes the Transformer Deep Dive series. Let us step back and survey what we have covered across all eight parts.

Part 1 introduced the original "Attention Is All You Need" architecture: encoder-decoder structure, multi-head self-attention, sinusoidal position encodings, and the fundamental insight that attention over sequences could replace recurrence entirely.

Part 2 traced the architectural evolution: the shift from encoder-decoder to decoder-only models, Pre-LayerNorm for training stability, and RMSNorm for computational efficiency. These changes seem incremental but were essential for scaling to billions of parameters.

Part 3 examined attention itself in detail: Rotary Position Embeddings (RoPE) that encode relative positions through rotation, Grouped-Query Attention (GQA) that balances quality and KV cache size, and FlashAttention that made exact attention practical through IO-aware computation.

Part 4 dissected the feed-forward network: SwiGLU activation providing better gradient flow, and Mixture of Experts (MoE) enabling dramatic scaling of model capacity without proportional compute increase.

Part 5 covered the training machinery: AdamW optimizer with decoupled weight decay, mixed-precision training with BF16, learning rate schedules, and the data pipeline decisions that determine model quality.

Part 6 focused on inference optimization: KV caching, quantization from FP16 down to INT4, speculative decoding for latency reduction, and continuous batching for throughput.

Part 7 catalogued the small but important details: bias removal, tied vs. untied embeddings, parallel attention and FFN, initialization schemes, and the numerical engineering that makes everything work reliably at scale.

Part 8 (this post) looked beyond attention to State Space Models, Mamba, Linear Attention, RWKV, and hybrid architectures that challenge the assumption that O(n^2) attention is necessary.

The field continues to evolve rapidly. As of this writing, the most exciting developments lie in three directions: hybrid architectures that combine attention's expressiveness with SSM efficiency; hardware-algorithm co-design that optimizes both the silicon and the math simultaneously; and scaling laws for alternative architectures that tell us how these models behave as we push them to hundreds of billions of parameters.

What began with "Attention Is All You Need" in 2017 has grown into a rich landscape of architectural innovations. Attention may not be all you need after all -- but understanding it deeply, along with the alternatives, is essential for anyone building or researching modern AI systems.

References

  • Vaswani, A. et al. "Attention Is All You Need." NeurIPS, 2017.
  • Gu, A. et al. "Efficiently Modeling Long Sequences with Structured State Spaces." ICLR, 2022.
  • Gu, A. et al. "On the Parameterization and Initialization of Diagonal State Space Models." NeurIPS, 2022.
  • Gu, A. and Dao, T. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv:2312.00752, 2023.
  • Dao, T. and Gu, A. "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality." ICML, 2024.
  • Katharopoulos, A. et al. "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." ICML, 2020.
  • Choromanski, K. et al. "Rethinking Attention with Performers." ICLR, 2021.
  • Peng, B. et al. "RWKV: Reinventing RNNs for the Transformer Era." EMNLP Findings, 2023.
  • Peng, B. et al. "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence." arXiv:2404.05892, 2024.
  • Lieber, O. et al. "Jamba: A Hybrid Transformer-Mamba Language Model." arXiv:2403.19887, 2024.
  • De, S. et al. "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." arXiv:2402.19427, 2024.
  • Poli, M. et al. "Hyena Hierarchy: Towards Larger Convolutional Language Models." ICML, 2023.
  • Sun, Y. et al. "Retentive Network: A Successor to Transformer for Large Language Models." arXiv:2307.08621, 2023.
  • Yang, S. et al. "Gated Linear Attention Transformers with Hardware-Efficient Training." ICML, 2024.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
