Transformer Deep Dive: Part 8 - Alternative Architectures
We have spent seven posts dissecting the transformer: its original design, architectural refinements, attention mechanisms, feed-forward networks, training recipes, inference optimization, and the small-but-important engineering details. All of these improvements share one fundamental limitation: self-attention still computes interactions between every pair of tokens, giving it $O(n^2)$ time and memory complexity in the sequence length $n$.
For a 4K context window, this is manageable. For 128K or 1M tokens, it becomes the primary bottleneck. This final post surveys the architectures that have emerged to break past this quadratic wall: State Space Models, Mamba, Linear Attention, RWKV, and the hybrid designs that combine the best of multiple paradigms.
The Quadratic Wall
Self-attention computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

The $QK^\top$ product creates an $n \times n$ attention matrix. Even with FlashAttention reducing memory from $O(n^2)$ to $O(n)$, the compute remains $O(n^2)$. Concretely:
| Sequence Length | Attention FLOPs ($n^2 \cdot d$, $d = 256$) | Ratio to 4K |
|---|---|---|
| 4,096 | ~4.3 billion | 1x |
| 32,768 | ~275 billion | 64x |
| 131,072 | ~4.4 trillion | 1,024x |
| 1,048,576 | ~281 trillion | 65,536x |
Doubling the sequence length quadruples the cost. This scaling fundamentally limits how far standard transformers can push context length, even with clever systems engineering.
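As a quick sanity check, the table's numbers can be reproduced in a few lines. The figures are consistent with counting roughly $n^2 \cdot d$ FLOPs at $d = 256$ -- an assumption chosen here to match the table, not a universal constant:

```python
def attention_flops(n: int, d: int = 256) -> int:
    # Representative count: the QK^T matmul alone is ~n^2 * d multiply-adds
    return n * n * d

base = attention_flops(4096)
for n in (4096, 32768, 131072, 1048576):
    f = attention_flops(n)
    print(f"{n:>9}: {f:.1e} FLOPs, {f // base:,}x")
```

Whatever constant you fold in for the number of layers or heads, the $n^2$ term dominates: the ratio column depends only on the quadratic scaling.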
The alternative architectures we examine here all achieve $O(n)$ or $O(n \log n)$ complexity by replacing the dense attention matrix with structured recurrences, linear maps, or state-based computations.
State Space Models (SSMs)
From Control Theory to Sequence Modeling
State Space Models have a long history in control theory and signal processing. The core idea is simple: maintain a hidden state that evolves over time according to a linear differential equation:

$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t) + D\,x(t)$$
where:
- $x(t)$ is the input signal at time $t$
- $h(t)$ is a hidden state vector of dimension $N$ (the "state size")
- $y(t)$ is the output
- $A \in \mathbb{R}^{N \times N}$ is the state transition matrix governing how the state evolves
- $B$ maps the input into the state
- $C$ reads out from the state to produce the output
- $D$ is a skip connection (often set to zero)
The power of this formulation is that $A$ determines the memory of the system. Different eigenstructures of $A$ correspond to different temporal behaviors: oscillatory patterns, exponential decays, or long-range dependencies.
Discretization: From Continuous to Discrete
Neural networks operate on discrete sequences, not continuous signals. We need to convert the continuous ODE into a discrete recurrence using a step size $\Delta$. The most common approach is the zero-order hold (ZOH) discretization:

$$\bar{A} = e^{\Delta A}, \qquad \bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right)\Delta B$$
This gives us the discrete recurrence:

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t$$
The step size $\Delta$ acts as a resolution parameter. A small $\Delta$ makes the model attend to fine-grained details; a large $\Delta$ makes it focus on coarse, long-range patterns. This provides an intuitive knob absent from attention-based models.
For the special case where $A$ is diagonal (which most modern SSMs assume), the matrix exponential simplifies to element-wise exponentiation:

$$\bar{A}_{ii} = e^{\Delta A_{ii}}$$
This makes the discretized state transition a simple element-wise operation, avoiding the matrix exponential.
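A minimal numerical sketch of the ZOH discretization for a diagonal $A$ (all values below are illustrative, not from any trained model):

```python
import numpy as np

# Illustrative diagonal SSM
a = np.array([-0.5, -1.0, -2.0])   # diagonal of A; negative => stable
b = np.ones(3)                      # input projection B
dt = 0.1                            # step size Delta

# Zero-order hold, element-wise for diagonal A:
#   A_bar = exp(dt * a),  B_bar = (exp(dt * a) - 1) / a * b
a_bar = np.exp(dt * a)
b_bar = (a_bar - 1.0) / a * b

# One step of the discrete recurrence h_t = A_bar * h_{t-1} + B_bar * x_t
h = np.zeros(3)
h = a_bar * h + b_bar * 1.0   # feed x_t = 1.0
```

Because each $a_i$ has negative real part, every entry of $\bar{A}$ has magnitude below 1, so the state cannot blow up under repeated updates.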
The Dual View: Recurrence and Convolution
A remarkable property of linear SSMs is that they have two equivalent computational modes:
Recurrent mode (for inference): Process tokens one at a time, maintaining a hidden state. Cost per token: $O(N)$ per channel, where $N$ is the state dimension.
Convolutional mode (for training): Unroll the recurrence into a global convolution. By expanding the recurrence:

$$y_t = \sum_{k=0}^{t} C\,\bar{A}^{k}\,\bar{B}\; x_{t-k}$$

This is a causal convolution with kernel:

$$\bar{K} = \left(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^{2}\bar{B},\; \dots,\; C\bar{A}^{L-1}\bar{B}\right)$$
Using the FFT, this convolution costs $O(L \log L)$ for a sequence of length $L$, which is far better than the $O(L^2)$ of attention. During training we use the convolutional mode for parallelism; during inference we switch to the recurrent mode for constant per-step cost.
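The equivalence of the two modes is easy to verify numerically for a small diagonal SSM (sizes and values here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16                        # state size, sequence length
a_bar = rng.uniform(0.5, 0.95, N)   # stable diagonal A_bar
b_bar = rng.normal(size=N)
c = rng.normal(size=N)
x = rng.normal(size=L)

# Recurrent mode: one state update per token
h = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    h = a_bar * h + b_bar * x[t]
    y_rec[t] = c @ h

# Convolutional mode: kernel K[k] = C A_bar^k B_bar, then causal convolution
K = np.array([(c * a_bar**k) @ b_bar for k in range(L)])
y_conv = np.array([sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(L)])

print(np.allclose(y_rec, y_conv))
```

The loop and the explicit kernel convolution produce identical outputs; the FFT only accelerates the latter, it does not change the result.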
S4: The Breakthrough
The S4 model ("Efficiently Modeling Long Sequences with Structured State Spaces") by Gu et al. (2022) made SSMs practical by solving three critical problems:
Problem 1: Long-range dependencies. A naive SSM with a random matrix $A$ suffers from vanishing gradients over long sequences (the powers of $\bar{A}$ decay exponentially). S4 solved this with the HiPPO (High-order Polynomial Projection Operator) initialization:

$$A_{nk} = -\begin{cases} \sqrt{(2n+1)(2k+1)} & n > k \\ n + 1 & n = k \\ 0 & n < k \end{cases}$$
This specific matrix structure has the property that the hidden state optimally compresses the entire history of the input by projecting it onto a basis of Legendre polynomials. This is not an arbitrary choice; it is provably optimal for a certain class of online function-approximation problems.
Problem 2: Efficient kernel computation. Computing the kernel $\bar{K}$ naively requires $L$ matrix-vector products with $\bar{A}$. S4 showed that when $A$ has special structure (diagonal plus low-rank, or purely diagonal in later work like S4D and S5), the kernel can be computed in roughly $\tilde{O}(N + L)$ time using a generating-function approach.
Problem 3: Stability. Training a recurrent system is inherently unstable if the eigenvalues of $\bar{A}$ exceed 1 in magnitude. S4 parameterizes $A$ to guarantee stability by constraining its eigenvalues to have negative real parts (in continuous time), which keeps the eigenvalues of $\bar{A}$ at magnitude less than 1 (in discrete time).
SSM Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class S4DLayer(nn.Module):
"""
Simplified S4D (diagonal) layer.
Uses diagonal state matrix for efficient computation.
"""
def __init__(self, d_model: int, state_dim: int = 64, dt_min: float = 0.001,
dt_max: float = 0.1):
super().__init__()
self.d_model = d_model
self.state_dim = state_dim
# HiPPO-inspired diagonal initialization (S4D-Lin variant)
# Real part: negative half-integers for stability
A_real = -0.5 * torch.ones(d_model, state_dim)
# Imaginary part: frequencies for oscillatory modes
A_imag = math.pi * torch.arange(state_dim).float().repeat(d_model, 1)
self.A_log_real = nn.Parameter(torch.log(-A_real)) # Log for positivity
self.A_imag = nn.Parameter(A_imag)
# Input and output projections (complex-valued)
self.B_re = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
self.B_im = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
self.C_re = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
self.C_im = nn.Parameter(torch.randn(d_model, state_dim) * 0.5)
# Learnable discretization step size
log_dt = torch.rand(d_model) * (
math.log(dt_max) - math.log(dt_min)
) + math.log(dt_min)
self.log_dt = nn.Parameter(log_dt)
self.D = nn.Parameter(torch.ones(d_model)) # Skip connection
def get_kernel(self, length: int) -> torch.Tensor:
"""Compute the convolution kernel of given length."""
dt = torch.exp(self.log_dt) # (d_model,)
A_real = -torch.exp(self.A_log_real) # Negative for stability
A = torch.complex(A_real, self.A_imag) # (d_model, state_dim)
B = torch.complex(self.B_re, self.B_im)
C = torch.complex(self.C_re, self.C_im)
# Discretize: A_bar = exp(dt * A)
dt_A = dt.unsqueeze(-1) * A # (d_model, state_dim)
A_bar = torch.exp(dt_A)
B_bar = (A_bar - 1.0) / (dt_A + 1e-8) * (dt.unsqueeze(-1) * B)
# Build kernel: K[k] = C @ A_bar^k @ B_bar
# Using geometric series for efficiency
powers = torch.arange(length, device=A.device).unsqueeze(0).unsqueeze(0)
# A_bar^k for each position
A_bar_k = A_bar.unsqueeze(-1) ** powers # (d_model, state_dim, length)
kernel = torch.einsum("dn,dnl->dl", C * B_bar, A_bar_k)
return kernel.real # (d_model, length)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, length, d_model)
Returns:
y: (batch, length, d_model)
"""
batch, length, _ = x.shape
# Compute kernel
kernel = self.get_kernel(length) # (d_model, length)
# Convolutional mode: use FFT for O(L log L) computation
x_t = x.transpose(1, 2) # (batch, d_model, length)
        # FFT convolution: zero-pad to 2L so the circular convolution equals
        # the linear (causal) convolution on the first L output positions
        n_fft = 2 * length
        k_f = torch.fft.rfft(kernel, n=n_fft)
        x_f = torch.fft.rfft(x_t, n=n_fft)
        y = torch.fft.irfft(x_f * k_f.unsqueeze(0), n=n_fft)
        y = y[..., :length]  # Trim to original length
# Skip connection
y = y + self.D.unsqueeze(0).unsqueeze(-1) * x_t
return y.transpose(1, 2) # (batch, length, d_model)
Mamba: Selective State Spaces
The Limitation of Linear SSMs
All the models discussed so far (S4, S4D, S5) share a critical limitation: the matrices $A$, $B$, and $C$ are fixed (independent of the input). This means the model treats every input token identically in terms of how it updates and reads the state.
Consider processing the sentence "The capital of France is Paris." A fixed SSM applies the same dynamics to "The" (uninformative) and "France" (critical). It cannot selectively store or retrieve information based on content. This is fundamentally different from attention, where the query-key mechanism explicitly selects which tokens to attend to.
Mamba's Key Innovation: Input-Dependent Parameters
Mamba (Gu and Dao, 2023) solves this by making $B$, $C$, and $\Delta$ functions of the input:

$$B_t = W_B\,x_t, \qquad C_t = W_C\,x_t, \qquad \Delta_t = \mathrm{softplus}(W_\Delta\,x_t)$$

(in practice, $\Delta_t$ is computed through a low-rank projection, as in the implementation below).
Now the model can:
- Selectively remember: When processing an important token, it can make $\Delta_t$ large, causing $\bar{A}_t = e^{\Delta_t A}$ to shrink toward zero, effectively resetting the state and writing new information via $\bar{B}_t x_t$.
- Selectively ignore: For unimportant tokens, a small $\Delta_t$ keeps $\bar{A}_t$ close to the identity, preserving the existing state and barely incorporating the new input.
- Selectively output: The input-dependent $C_t$ controls what information is read from the state at each step.
This selectivity mechanism provides content-aware processing comparable to attention, but within a recurrent framework that maintains $O(n)$ complexity.
The Hardware-Aware Parallel Scan
Making parameters input-dependent breaks the convolution view: since $\bar{A}_t$ varies with $t$, we can no longer express the computation as a fixed convolution. However, we can still parallelize using the parallel scan algorithm.

The recurrence $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ is an instance of a first-order linear recurrence, which can be computed in $O(n)$ work and $O(\log n)$ depth using the parallel prefix sum (scan) algorithm. The key insight is that the pairs $(\bar{A}_t, \bar{B}_t x_t)$ can be composed associatively: two consecutive steps $(a_1, b_1)$ then $(a_2, b_2)$ compose into

$$(a_2 a_1,\; a_2 b_1 + b_2)$$
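A small sketch of why this works: the composition rule below is associative, which is exactly what a tree-structured parallel scan requires. The elements $(a_t, b_t)$ are taken as scalars here for clarity:

```python
import numpy as np

def combine(e1, e2):
    # Applying e1 then e2 to a state h gives a2*(a1*h + b1) + b2,
    # i.e. the composed element (a2*a1, a2*b1 + b2)
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

rng = np.random.default_rng(1)
elems = [(rng.uniform(0.5, 1.0), rng.normal()) for _ in range(8)]

# Associativity: grouping does not matter, so a parallel tree
# reduction computes the same prefixes as a serial loop
e1, e2, e3 = elems[:3]
assert np.allclose(combine(combine(e1, e2), e3),
                   combine(e1, combine(e2, e3)))

# The b-component of the running composition reproduces h_t = a_t h_{t-1} + b_t
h, acc = 0.0, elems[0]
for t, (a_t, b_t) in enumerate(elems):
    h = a_t * h + b_t
    if t > 0:
        acc = combine(acc, (a_t, b_t))
    assert np.isclose(acc[1], h)
```

Because grouping is free, the prefix states for all positions can be computed in $O(\log n)$ parallel steps instead of $n$ sequential ones.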
Mamba implements this with a custom CUDA kernel that is carefully designed for GPU memory hierarchy:
- Load from HBM to SRAM: The input $x$ and the discretized parameters $\bar{A}$, $\bar{B}$, and $C$ are loaded once.
- Scan in SRAM: The parallel scan runs entirely in on-chip SRAM, avoiding repeated HBM reads.
- Write output to HBM: Only the final output is written back.
This is directly analogous to the IO-awareness principle behind FlashAttention, applied to a different computation.
Mamba Architecture
Each Mamba block replaces the attention + FFN pair of a transformer layer:
Input x
|
v
Linear (expand by factor E=2)
|
+--> Branch A: Conv1d -> SiLU -> Selective SSM --+
| |
+--> Branch B: SiLU (gate) ----------------------+
|
v
Element-wise multiply
|
v
Linear (project back)
|
v
Output y
This gated architecture is reminiscent of the GLU family we discussed in Part 4. The gate branch (Branch B) modulates the SSM output (Branch A), providing an additional nonlinear mixing mechanism.
Complete Mamba Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelectiveSSM(nn.Module):
"""
Selective State Space Model -- the core of Mamba.
Implements input-dependent B, C, and delta with a parallel scan.
"""
    def __init__(self, d_inner: int, d_state: int = 16, dt_rank: int | None = None,
                 dt_min: float = 0.001, dt_max: float = 0.1):
super().__init__()
self.d_inner = d_inner
self.d_state = d_state
self.dt_rank = dt_rank or math.ceil(d_inner / 16)
# Input-dependent projections for B, C, and dt
self.x_proj = nn.Linear(d_inner, self.dt_rank + 2 * d_state, bias=False)
self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)
# Initialize dt bias so that softplus(bias) is in [dt_min, dt_max]
dt = torch.exp(
torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
inv_dt = dt + torch.log(-torch.expm1(-dt)) # softplus inverse
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# State matrix A: initialized as negative real values (log scale)
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
# D: skip connection
self.D = nn.Parameter(torch.ones(d_inner))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, length, d_inner) -- post-convolution, post-SiLU
Returns:
y: (batch, length, d_inner)
"""
batch, length, d_inner = x.shape
d_state = self.d_state
# Compute input-dependent B, C, dt
x_dbl = self.x_proj(x) # (batch, length, dt_rank + 2*d_state)
dt, B, C = x_dbl.split([self.dt_rank, d_state, d_state], dim=-1)
dt = self.dt_proj(dt) # (batch, length, d_inner)
dt = F.softplus(dt) # Ensure positive step sizes
# Discretize A
A = -torch.exp(self.A_log.float()) # (d_inner, d_state), negative for stability
# Compute discretized parameters
# dA = exp(dt * A) -- element-wise for diagonal A
dA = torch.exp(
torch.einsum("bld,dn->bldn", dt, A)
) # (batch, length, d_inner, d_state)
# dB = dt * B (simplified ZOH for small dt)
dB_x = torch.einsum("bld,bln->bldn", dt * x, B)
# (batch, length, d_inner, d_state)
# Run the selective scan
y = self.parallel_scan(dA, dB_x, C)
# Skip connection
y = y + x * self.D.unsqueeze(0).unsqueeze(0)
return y
def parallel_scan(
self,
dA: torch.Tensor,
dB_x: torch.Tensor,
C: torch.Tensor,
) -> torch.Tensor:
"""
Implements the parallel scan for the selective SSM recurrence.
Recurrence: h_t = dA_t * h_{t-1} + dB_x_t
Output: y_t = C_t @ h_t
This is a simplified (pure PyTorch) implementation. The real Mamba uses
custom CUDA kernels for hardware-aware execution.
Args:
dA: (batch, length, d_inner, d_state) -- discretized state transition
dB_x: (batch, length, d_inner, d_state) -- discretized input contribution
C: (batch, length, d_state) -- output projection
Returns:
y: (batch, length, d_inner)
"""
batch, length, d_inner, d_state = dA.shape
# Sequential scan (reference implementation)
# In practice, this would use a work-efficient parallel scan on GPU
h = torch.zeros(batch, d_inner, d_state, device=dA.device, dtype=dA.dtype)
ys = []
for t in range(length):
h = dA[:, t] * h + dB_x[:, t] # (batch, d_inner, d_state)
y_t = torch.einsum("bdn,bn->bd", h, C[:, t]) # (batch, d_inner)
ys.append(y_t)
return torch.stack(ys, dim=1) # (batch, length, d_inner)
class MambaBlock(nn.Module):
"""
Full Mamba block: replaces the Attention + FFN pair of a transformer layer.
Architecture:
x -> Linear(expand) -> split into two branches
Branch A: Conv1d -> SiLU -> SelectiveSSM
Branch B: SiLU (gate)
Output: (Branch A * Branch B) -> Linear(project)
"""
def __init__(self, d_model: int, d_state: int = 16, expand: int = 2,
d_conv: int = 4):
super().__init__()
self.d_model = d_model
self.d_inner = d_model * expand
self.d_state = d_state
self.d_conv = d_conv
# Input projection: expand and split into x and gate
self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=False)
# Depthwise convolution for local context
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
padding=d_conv - 1,
groups=self.d_inner, # Depthwise
bias=True,
)
# Selective SSM
self.ssm = SelectiveSSM(
d_inner=self.d_inner,
d_state=d_state,
)
# Output projection: compress back to d_model
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
# Normalization
self.norm = nn.RMSNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, length, d_model)
Returns:
(batch, length, d_model)
"""
residual = x
x = self.norm(x)
# Project and split
xz = self.in_proj(x) # (batch, length, 2 * d_inner)
x_branch, z = xz.chunk(2, dim=-1) # Each: (batch, length, d_inner)
# Branch A: conv -> activation -> SSM
x_branch = x_branch.transpose(1, 2) # (batch, d_inner, length)
x_branch = self.conv1d(x_branch)[:, :, :x.size(1)] # Causal: trim future
x_branch = x_branch.transpose(1, 2) # (batch, length, d_inner)
x_branch = F.silu(x_branch)
x_branch = self.ssm(x_branch)
# Branch B: gate
z = F.silu(z)
# Combine and project
y = x_branch * z
y = self.out_proj(y)
return y + residual
class MambaModel(nn.Module):
"""A complete Mamba language model."""
def __init__(self, vocab_size: int, d_model: int, n_layers: int,
d_state: int = 16, expand: int = 2, d_conv: int = 4):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state, expand, d_conv)
for _ in range(n_layers)
])
self.norm_f = nn.RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# Weight tying (common in Mamba models)
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
x = self.norm_f(x)
return self.lm_head(x)
Mamba's Performance Profile
Mamba's advantages are most pronounced during inference. Since the SSM state has a fixed size ($d_{\text{inner}} \times d_{\text{state}}$ values per layer, independent of context length), the per-token inference cost is constant regardless of sequence length. Compare this to a transformer's KV cache, which grows linearly with context:
| Metric | Transformer (1.4B) | Mamba (1.4B) |
|---|---|---|
| Training throughput | 1x (baseline) | ~1x (comparable) |
| Inference throughput | 1x (baseline) | ~5x faster |
| Memory (inference) | $O(n)$ KV cache | $O(1)$ fixed state |
| Perplexity (Pile) | ~14.2 | ~14.0 |
| Long-range retrieval | Excellent | Good (not perfect) |
The one area where Mamba has shown weakness compared to transformers of equal size is in-context learning and retrieval tasks that require precise copying from the context. This is because the fixed-size state must compress all past information, making exact recall difficult for very specific tokens.
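To make the memory comparison concrete, here is a back-of-the-envelope sketch; the layer counts and dimensions below are assumed for illustration, not taken from any specific checkpoint:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per=2):
    # K and V tensors per layer, per token (FP16/BF16 => 2 bytes each)
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per

def mamba_state_bytes(n_layers, d_inner, d_state, d_conv=4, bytes_per=2):
    # SSM state plus the small conv buffer, per layer -- no seq_len term
    return n_layers * (d_inner * d_state + d_inner * d_conv) * bytes_per

# Illustrative ~1.4B-scale configurations (assumed values)
print(f"KV cache @128K: {kv_cache_bytes(24, 16, 128, 131072) / 1e9:.1f} GB")  # ~25.8 GB
print(f"Mamba state:    {mamba_state_bytes(48, 4096, 16) / 1e6:.1f} MB")      # ~7.9 MB
```

The transformer number doubles every time the context doubles; the Mamba number does not move at all, which is the whole point of a fixed-size state.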
Linear Attention
Removing the Softmax
Standard attention requires materializing the $n \times n$ attention matrix because the softmax operates row-wise, coupling all query-key interactions:

$$\mathrm{Attn}(Q, K, V)_i = \sum_{j} \frac{\exp\!\left(q_i^\top k_j\right)}{\sum_{l} \exp\!\left(q_i^\top k_l\right)}\, v_j$$

Linear attention (Katharopoulos et al., 2020) replaces the softmax with a decomposable kernel:

$$\mathrm{LinAttn}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_{j} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j} \phi(k_j)}$$

where $\phi$ is a feature map applied element-wise to queries and keys.
The Associativity Trick
The computational savings come from exploiting the associativity of matrix multiplication. Consider the unnormalized case:

Standard (left-to-right): $\left(\phi(Q)\,\phi(K)^\top\right) V$ costs $O(n^2 d)$.

Linear (right-to-left): $\phi(Q)\left(\phi(K)^\top V\right)$ costs $O(n d^2)$.

By computing $\phi(K)^\top V$ first (a $d \times d$ matrix), then multiplying each query row by it, we avoid constructing the $n \times n$ attention matrix entirely. When $d \ll n$ (which is typically true for long sequences), this is a massive saving.
For causal (autoregressive) attention, we need to ensure token $i$ only attends to tokens $j \le i$. This is achieved with a cumulative sum formulation:

$$S_i = S_{i-1} + \phi(k_i)\,v_i^\top, \qquad z_i = z_{i-1} + \phi(k_i), \qquad y_i = \frac{\phi(q_i)^\top S_i}{\phi(q_i)^\top z_i}$$

This is a recurrence on the state matrix $S_i \in \mathbb{R}^{d \times d}$, making it compatible with efficient $O(1)$-per-token inference.
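The cumulative-sum formulation can be checked against an explicit masked quadratic reference in a few lines, here using the ELU+1 feature map:

```python
import numpy as np

def phi(x):
    # ELU(x) + 1: strictly positive feature map
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(2)
L, D = 6, 4
q, k, v = (rng.normal(size=(L, D)) for _ in range(3))

# Recurrent causal linear attention: constant-size state per step
S = np.zeros((D, D))
z = np.zeros(D)
y_rec = np.empty((L, D))
for t in range(L):
    S += np.outer(phi(k[t]), v[t])   # S_t = S_{t-1} + phi(k_t) v_t^T
    z += phi(k[t])                   # z_t = z_{t-1} + phi(k_t)
    y_rec[t] = phi(q[t]) @ S / (phi(q[t]) @ z)

# Quadratic reference: explicit masked kernel attention
A = np.tril(phi(q) @ phi(k).T)
y_ref = (A @ v) / A.sum(axis=-1, keepdims=True)
print(np.allclose(y_rec, y_ref))
```

The recurrent loop touches only a $d \times d$ state per token, yet produces exactly the same outputs as the masked $n \times n$ formulation.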
Feature Maps
The choice of determines the quality-efficiency tradeoff:
| Feature Map | Definition | Properties |
|---|---|---|
| Identity | $\phi(x) = x$ | Simplest; can produce negative attention weights |
| ELU+1 | $\phi(x) = \mathrm{ELU}(x) + 1$ | Positive; smooth; simple to implement |
| ReLU | $\phi(x) = \mathrm{ReLU}(x)$ | Positive; sparse; may lose information |
| Random Fourier | Trigonometric random features | Approximates the softmax kernel; unbiased |
| Performer (FAVOR+) | Positive random features | Unbiased softmax approximation with positivity |
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearAttention(nn.Module):
"""
Linear attention with causal masking.
Uses the ELU+1 feature map for simplicity and positivity.
"""
def __init__(self, d_model: int, n_heads: int, feature_map: str = "elu"):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.d_model = d_model
self.wq = nn.Linear(d_model, d_model, bias=False)
self.wk = nn.Linear(d_model, d_model, bias=False)
self.wv = nn.Linear(d_model, d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
self.feature_map = feature_map
def phi(self, x: torch.Tensor) -> torch.Tensor:
"""Apply feature map to ensure positive values."""
if self.feature_map == "elu":
return F.elu(x) + 1.0
elif self.feature_map == "relu":
return F.relu(x)
elif self.feature_map == "identity":
return x
else:
raise ValueError(f"Unknown feature map: {self.feature_map}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Causal linear attention using the cumulative sum formulation.
Args:
x: (batch, length, d_model)
Returns:
(batch, length, d_model)
"""
batch, length, _ = x.shape
q = self.wq(x).view(batch, length, self.n_heads, self.head_dim)
k = self.wk(x).view(batch, length, self.n_heads, self.head_dim)
v = self.wv(x).view(batch, length, self.n_heads, self.head_dim)
# Apply feature map
q = self.phi(q) # (batch, length, heads, head_dim)
k = self.phi(k)
# Causal linear attention via cumulative sum
# S_t = sum_{i<=t} phi(k_i) @ v_i^T (accumulated KV state)
# z_t = sum_{i<=t} phi(k_i) (accumulated normalizer)
# y_t = (q_t @ S_t) / (q_t @ z_t)
        # Compute outer products phi(k_i) v_i^T for all positions.
        # Note: this materializes a (B, L, H, D, D) tensor -- fine for a
        # reference implementation; production kernels compute it in chunks.
        kv = torch.einsum("blhd,blhe->blhde", k, v)  # (B, L, H, D, D)
# Cumulative sum along sequence dimension
S = torch.cumsum(kv, dim=1) # (B, L, H, D, D)
# Normalizer: cumulative sum of keys
z = torch.cumsum(k, dim=1) # (B, L, H, D)
# Compute output
y_num = torch.einsum("blhd,blhde->blhe", q, S) # (B, L, H, D)
y_den = torch.einsum("blhd,blhd->blh", q, z) # (B, L, H)
y_den = y_den.unsqueeze(-1).clamp(min=1e-6) # Prevent division by zero
y = y_num / y_den # (B, L, H, D)
y = y.reshape(batch, length, self.d_model)
return self.wo(y)
Limitations of Linear Attention
Despite the appealing $O(n)$ complexity, linear attention has well-documented limitations:
1. No sharp attention patterns. Softmax attention can assign nearly all weight to a single token (approaching a one-hot distribution). Linear attention with positive feature maps produces smoother distributions, making it harder for the model to perform precise retrieval.

2. The dilution problem. The cumulative state $S_i$ accumulates all past KV pairs without discounting. Over very long sequences, the state becomes an average over many tokens, diluting individual contributions. Softmax attention avoids this because it re-normalizes at every step.

3. Quality gap at moderate scale. Empirically, linear attention models tend to underperform softmax attention models of the same size by 1-3 perplexity points, though the gap narrows with scale.
These limitations motivated the development of more sophisticated approaches like RetNet, GLA (Gated Linear Attention), and the hybrid architectures discussed below.
RWKV: Transformer-Level Performance, RNN-Level Efficiency
Design Philosophy
RWKV (Peng et al., 2023) takes a different approach from SSMs. Rather than starting from control theory, it directly reformulates the attention mechanism to be computable as a linear recurrence. The name RWKV comes from the four core elements: Receptance (R), Weight (W), Key (K), and Value (V).
The WKV Mechanism
The core computation in RWKV is the "WKV" (Weighted Key-Value) operator:

$$\mathrm{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \;+\; e^{u + k_t}}$$
where:
- $w$ is a learned per-channel decay rate (always positive, so past tokens are exponentially downweighted)
- $u$ is a learned bonus for the current token (allowing the model to upweight the current position)
- $k_i$, $v_i$ are key and value projections of the input at position $i$
This can be computed recurrently by maintaining two running sums, a numerator state $a_t$ and a denominator state $b_t$:

$$a_t = e^{-w}\,a_{t-1} + e^{k_t}\,v_t, \qquad b_t = e^{-w}\,b_{t-1} + e^{k_t}, \qquad \mathrm{wkv}_t = \frac{a_{t-1} + e^{u + k_t}\,v_t}{b_{t-1} + e^{u + k_t}}$$

The exponential decay is the key to RWKV's long-range capability: by learning different decay rates per channel, the model can maintain both short-term and long-term memory across different dimensions of the hidden state.
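A small numerical check that the recurrent form matches the WKV definition (random values; production RWKV kernels additionally subtract a running maximum inside the exponents for numerical stability):

```python
import numpy as np

rng = np.random.default_rng(3)
T, C = 8, 3                     # sequence length, channels
k = rng.normal(size=(T, C))
v = rng.normal(size=(T, C))
w = np.abs(rng.normal(size=C))  # per-channel decay (positive)
u = rng.normal(size=C)          # current-token bonus

# Recurrent form: numerator state a and denominator state b
a = np.zeros(C)
b = np.zeros(C)
wkv_rec = []
for t in range(T):
    wkv_rec.append((a + np.exp(u + k[t]) * v[t]) / (b + np.exp(u + k[t])))
    a = np.exp(-w) * a + np.exp(k[t]) * v[t]
    b = np.exp(-w) * b + np.exp(k[t])

# Direct evaluation of the WKV definition
wkv_dir = []
for t in range(T):
    num = np.exp(u + k[t]) * v[t]
    den = np.exp(u + k[t])
    for i in range(t):
        decay = np.exp(-(t - 1 - i) * w + k[i])
        num = num + decay * v[i]
        den = den + decay
    wkv_dir.append(num / den)

print(np.allclose(wkv_rec, wkv_dir))
```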
RWKV Block Structure
An RWKV block consists of two sub-blocks:
Input x
|
v
LayerNorm -> Time Mixing (WKV) -> + (residual) -> LayerNorm -> Channel Mixing -> + (residual)
Time mixing handles sequence-level interactions (analogous to attention):

$$r_t = W_r\left(\mu_r x_t + (1-\mu_r) x_{t-1}\right), \quad k_t = W_k\left(\mu_k x_t + (1-\mu_k) x_{t-1}\right), \quad v_t = W_v\left(\mu_v x_t + (1-\mu_v) x_{t-1}\right)$$

$$o_t = W_o\left(\sigma(r_t) \odot \mathrm{wkv}_t\right)$$

The parameters $\mu$ control a token-shift mechanism: each projection sees a learned interpolation between the current token and the previous one, providing a simple form of local context without convolution.
Channel mixing handles feature-level interactions (analogous to the FFN):

$$k_t = W_k\left(\mu_k x_t + (1-\mu_k) x_{t-1}\right), \qquad r_t = W_r\left(\mu_r x_t + (1-\mu_r) x_{t-1}\right)$$

$$o_t = \sigma(r_t) \odot \left(W_v\,\max(k_t, 0)^2\right)$$

The squared ReLU provides the nonlinearity, while the receptance gate $\sigma(r_t)$ controls information flow.
RWKV vs. Attention: A Structural Comparison
| Property | Softmax Attention | RWKV |
|---|---|---|
| Sequence interaction | All-pairs ($O(n^2)$) | Recurrent ($O(n)$) |
| Memory at inference | KV cache grows with $n$ | Fixed-size state |
| Training | Parallel (matmul) | Parallel (custom kernel) |
| Positional awareness | Explicit (RoPE, etc.) | Implicit (decay + token shift) |
| Retrieval precision | Excellent (sharp attention) | Good (exponential decay) |
| Very long range (>100K) | Possible but expensive | Natural and efficient |
RWKV has been scaled to 14B parameters (RWKV-5 "Eagle" and RWKV-6 "Finch"), achieving performance competitive with similarly-sized transformers on standard benchmarks while offering constant-memory inference.
Hybrid Architectures: The Best of Both Worlds
Why Hybrids?
Each architecture excels in different regimes:
- Attention is best for tasks requiring precise retrieval, in-context learning, and complex reasoning over specific context elements.
- SSM/Mamba excels at long-range dependencies, efficient inference, and streaming applications.
- Linear methods offer the best theoretical complexity but sacrifice some expressive power.
Hybrid architectures combine these strengths by using different mechanisms in different layers.
Jamba (AI21 Labs)
Jamba (Lieber et al., 2024) is the first production-scale hybrid Mamba-Transformer model. Its architecture interleaves Mamba layers with attention layers and adds Mixture of Experts:
| Layer Index | Type | MoE |
|---|---|---|
| 0 | Attention | Yes (top-2 of 16 experts) |
| 1-6 | Mamba | Yes |
| 7 | Attention | Yes |
| 8-13 | Mamba | Yes |
| 14 | Attention | Yes |
| ... | ... | ... |
The ratio is approximately 1 attention layer for every 7 Mamba layers. The attention layers handle precise retrieval tasks while the Mamba layers efficiently process the bulk of the sequence context. Jamba achieves:
- 256K context length
- 3x throughput improvement over a pure transformer of similar quality
- Comparable quality to Mixtral 8x7B on standard benchmarks
- Significantly reduced KV cache size (only the few attention layers need a cache)
Mamba-2: Structured State Space Duality
Mamba-2 (Dao and Gu, 2024) establishes a deep theoretical connection between SSMs and attention, which they call the State Space Duality (SSD) framework. The key insight is that the selective SSM computation can be expressed as a structured (semiseparable) matrix multiplication:

$$y = M\,x$$

where $M$ is a semiseparable matrix determined by the discretized state transitions. This matrix has the same lower-triangular structure as a causal attention matrix with a specific masking pattern.
This duality enables:
- Larger state dimensions ($N = 64$ or $128$ instead of $16$) because the computation maps to efficient matrix multiplies on modern GPUs.
- Direct architectural interleaving with attention layers, since both operate through matrix multiplications of compatible shapes.
- 2-8x speedup over Mamba-1 through better hardware utilization.
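The duality is easy to see in the scalar-state case: unrolling the selective recurrence yields a lower-triangular matrix $M$ with $M_{t,i} = C_t \left(\prod_{j=i+1}^{t} \bar{A}_j\right) B_i$, so the whole computation is one structured matmul. A sketch with assumed scalar values:

```python
import numpy as np

rng = np.random.default_rng(4)
L = 6
a = rng.uniform(0.5, 0.95, L)   # per-step transitions A_bar_t (scalar state)
b = rng.normal(size=L)          # input coefficients B_t
c = rng.normal(size=L)          # output coefficients C_t
x = rng.normal(size=L)

# Recurrent view: h_t = a_t h_{t-1} + b_t x_t,  y_t = c_t h_t
h = 0.0
y_rec = np.empty(L)
for t in range(L):
    h = a[t] * h + b[t] * x[t]
    y_rec[t] = c[t] * h

# Matrix (SSD) view: y = M x with M[t, i] = c_t * (a_{i+1} ... a_t) * b_i
M = np.zeros((L, L))
for t in range(L):
    for i in range(t + 1):
        M[t, i] = c[t] * np.prod(a[i + 1 : t + 1]) * b[i]

print(np.allclose(y_rec, M @ x))
```

Squint at $M$ and it is a causal attention matrix whose "scores" are products of decay factors -- which is exactly why SSD lets SSM layers and attention layers share kernels and block structure.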
Practical Hybrid Design Patterns
Based on published architectures, several hybrid design patterns have emerged:
Pattern 1: Sparse Interleaving (Jamba) Place attention layers at regular intervals (every 6-8 layers). The attention layers act as "global synchronization points" that can perform precise retrieval, while Mamba layers handle local and medium-range processing efficiently.
Pattern 2: First-and-Last (StripedHyena) Use attention in the first and last few layers, with SSM layers in the middle. The intuition is that early layers need to establish token-level representations (which benefits from attention), and final layers need precise output selection, while middle layers primarily propagate and transform features.
Pattern 3: Parallel Hybrid Run attention and SSM in parallel within the same layer and add their outputs. This preserves both pathways without increasing depth. Some Mamba-2 variants explore this pattern.
import torch
import torch.nn as nn
class HybridBlock(nn.Module):
"""
A flexible hybrid block that can be configured as
Mamba-only, Attention-only, or parallel hybrid.
"""
def __init__(self, d_model: int, n_heads: int, d_ff: int,
d_state: int = 16, mode: str = "mamba"):
super().__init__()
self.mode = mode
self.norm = nn.RMSNorm(d_model)
        if mode in ("attention", "parallel"):
            # BiasFreeLLaMAAttention / BiasFreeSwiGLU are the bias-free
            # modules built in earlier parts of this series
            self.attn = BiasFreeLLaMAAttention(d_model, n_heads, n_heads)
        if mode in ("mamba", "parallel"):
            self.mamba = MambaBlock(d_model, d_state=d_state)
        self.ffn = BiasFreeSwiGLU(d_model, d_ff)
        self.ffn_norm = nn.RMSNorm(d_model)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mode == "attention":
            x = x + self.attn(self.norm(x))
        elif self.mode == "mamba":
            x = self.mamba(x)  # MambaBlock includes its own norm + residual
        elif self.mode == "parallel":
            # Both branches read the same input; MambaBlock carries the
            # residual stream and the attention output is added on top
            x = self.mamba(x) + self.attn(self.norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x
class JambaStyleModel(nn.Module):
"""
Simplified Jamba-style hybrid: interleave attention and Mamba layers.
Every `attn_every` layers uses attention; others use Mamba.
"""
def __init__(self, vocab_size: int, d_model: int, n_layers: int,
n_heads: int, d_ff: int, d_state: int = 16,
attn_every: int = 7):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList()
for i in range(n_layers):
if i % attn_every == 0:
self.layers.append(
HybridBlock(d_model, n_heads, d_ff, d_state, mode="attention")
)
else:
self.layers.append(
HybridBlock(d_model, n_heads, d_ff, d_state, mode="mamba")
)
self.norm_f = nn.RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
x = self.norm_f(x)
return self.lm_head(x)
Practical Decision Matrix
When choosing an architecture for a new project, the decision depends on your specific constraints:
| Criterion | Best Choice | Runner-Up | Avoid |
|---|---|---|---|
| Quality at any cost | Transformer (dense attention) | Hybrid (Jamba-style) | Pure linear attention |
| Very long context (>128K) | Mamba / Hybrid | RWKV | Pure transformer |
| Low-latency inference | Mamba | RWKV | Transformer (large KV cache) |
| Streaming / real-time | Mamba, RWKV | Hybrid | Transformer |
| In-context learning | Transformer | Hybrid | Pure SSM |
| Retrieval-heavy tasks | Transformer | Hybrid | Mamba (weaker at exact recall) |
| Memory-constrained deployment | RWKV, Mamba | Quantized transformer | Full-precision transformer |
| Largest proven scale (>100B) | Transformer | Hybrid (early results) | Pure Mamba (not yet proven) |
| Research flexibility | Transformer (most tooling) | Mamba (growing ecosystem) | RWKV (smaller community) |
The practical recommendation in 2025 is:
- Default to transformers if you need the best quality and your context length is under 32K. The ecosystem (frameworks, quantization tools, serving infrastructure) is most mature.
- Use hybrids if you need long context (>64K) with high quality. The combination of a few attention layers with many SSM layers gives the best quality-efficiency tradeoff.
- Use pure Mamba if your primary constraint is inference throughput or memory, and you can accept a small quality tradeoff.
- Use RWKV if you need an open-source, well-tested alternative with RNN-level inference efficiency and an active community.
Conclusion: The Journey Through Transformers
This post concludes the Transformer Deep Dive series. Let us step back and survey what we have covered across all eight parts.
Part 1 introduced the original "Attention Is All You Need" architecture: encoder-decoder structure, multi-head self-attention, sinusoidal position encodings, and the fundamental insight that attention over sequences could replace recurrence entirely.
Part 2 traced the architectural evolution: the shift from encoder-decoder to decoder-only models, Pre-LayerNorm for training stability, and RMSNorm for computational efficiency. These changes seem incremental but were essential for scaling to billions of parameters.
Part 3 examined attention itself in detail: Rotary Position Embeddings (RoPE) that encode relative positions through rotation, Grouped-Query Attention (GQA) that balances quality and KV cache size, and FlashAttention that made exact attention practical through IO-aware computation.
Part 4 dissected the feed-forward network: SwiGLU activation providing better gradient flow, and Mixture of Experts (MoE) enabling dramatic scaling of model capacity without proportional compute increase.
Part 5 covered the training machinery: AdamW optimizer with decoupled weight decay, mixed-precision training with BF16, learning rate schedules, and the data pipeline decisions that determine model quality.
Part 6 focused on inference optimization: KV caching, quantization from FP16 down to INT4, speculative decoding for latency reduction, and continuous batching for throughput.
Part 7 catalogued the small but important details: bias removal, tied vs. untied embeddings, parallel attention and FFN, initialization schemes, and the numerical engineering that makes everything work reliably at scale.
Part 8 (this post) looked beyond attention to State Space Models, Mamba, Linear Attention, RWKV, and hybrid architectures that challenge the assumption that attention is necessary.
The field continues to evolve rapidly. As of this writing, the most exciting developments lie in three directions: hybrid architectures that combine attention's expressiveness with SSM efficiency; hardware-algorithm co-design that optimizes both the silicon and the math simultaneously; and scaling laws for alternative architectures that tell us how these models behave as we push them to hundreds of billions of parameters.
What began with "Attention Is All You Need" in 2017 has grown into a rich landscape of architectural innovations. Attention may not be all you need after all -- but understanding it deeply, along with the alternatives, is essential for anyone building or researching modern AI systems.
References
- Vaswani, A. et al. "Attention Is All You Need." NeurIPS, 2017.
- Gu, A. et al. "Efficiently Modeling Long Sequences with Structured State Spaces." ICLR, 2022.
- Gu, A. et al. "On the Parameterization and Initialization of Diagonal State Space Models." NeurIPS, 2022.
- Gu, A. and Dao, T. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv:2312.00752, 2023.
- Dao, T. and Gu, A. "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality." ICML, 2024.
- Katharopoulos, A. et al. "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." ICML, 2020.
- Choromanski, K. et al. "Rethinking Attention with Performers." ICLR, 2021.
- Peng, B. et al. "RWKV: Reinventing RNNs for the Transformer Era." EMNLP Findings, 2023.
- Peng, B. et al. "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence." arXiv:2404.05892, 2024.
- Lieber, O. et al. "Jamba: A Hybrid Transformer-Mamba Language Model." arXiv:2403.19887, 2024.
- De, S. et al. "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." arXiv:2402.19427, 2024.
- Poli, M. et al. "Hyena Hierarchy: Towards Larger Convolutional Language Models." ICML, 2023.
- Sun, Y. et al. "Retentive Network: A Successor to Transformer for Large Language Models." arXiv:2307.08621, 2023.
- Yang, S. et al. "Gated Linear Attention Transformers with Hardware-Efficient Training." ICML, 2024.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.