Transformer Deep Dive: Part 4 - FFN Modifications
In the previous posts, we examined the macro architecture (Part 2) and the attention mechanism (Part 3). Now we turn to the other half of every Transformer block: the Feed-Forward Network (FFN). Attention decides which tokens to mix; the FFN decides what to do with the mixed representations. Despite containing roughly two-thirds of the model's parameters, the FFN receives far less attention (no pun intended) in most discussions. That is a mistake -- the evolution from simple ReLU MLPs to gated architectures and Mixture of Experts represents some of the most impactful changes in modern LLM design.
1. The Role of the FFN in Transformers
What Does the FFN Actually Do?
In the original Transformer, the FFN is a simple two-layer MLP applied independently to each token position:

$$\text{FFN}(x) = \max(0,\, xW_1 + b_1)\,W_2 + b_2$$

where $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ projects from the model dimension to a larger hidden dimension (typically $d_{\text{ff}} = 4d_{\text{model}}$), and $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ projects back.
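As a minimal PyTorch sketch (the class and parameter names here are illustrative, not from any particular codebase):

```python
import torch
import torch.nn as nn

class StandardFFN(nn.Module):
    """The original two-layer FFN: expand, apply ReLU, compress."""
    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        d_ff = d_ff or 4 * d_model  # the original paper uses d_ff = 4 * d_model
        self.w1 = nn.Linear(d_model, d_ff)  # expansion (with bias, as in the original)
        self.w2 = nn.Linear(d_ff, d_model)  # compression back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))
```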
The FFN serves several distinct purposes:
- Non-linear feature transformation: The attention mechanism is essentially a weighted sum -- a linear operation over values. The FFN introduces the non-linearity required for the network to learn complex functions.
- Dimensional expansion and compression: The expansion to $d_{\text{ff}}$ creates a higher-dimensional space where more complex feature interactions can occur, before compressing back to $d_{\text{model}}$.
- Per-token processing: Unlike attention, which mixes information across positions, the FFN operates independently on each token. It transforms what each token represents, rather than where it attends.
- Knowledge storage: This is perhaps the most fascinating role. Geva et al. (2021) showed that FFN layers act as key-value memories, where the columns of $W_1$ serve as keys (pattern detectors) and the rows of $W_2$ serve as values (associated outputs). When an input pattern matches a key (high activation in the hidden layer), the corresponding value is retrieved and added to the residual stream. This means that much of the factual knowledge in a language model is stored in the FFN weights.
Parameter Distribution
The FFN dominates the parameter count of every Transformer block:
| Component | Parameters per Block | Fraction |
|---|---|---|
| Attention: $W_Q, W_K, W_V, W_O$ | $4d_{\text{model}}^2$ | ~33% |
| FFN: $W_1, W_2$ | $8d_{\text{model}}^2$ (when $d_{\text{ff}} = 4d_{\text{model}}$) | ~67% |
For a model like LLaMA 2 7B ($d_{\text{model}} = 4096$, $d_{\text{ff}} = 11008$, 32 layers), the FFN parameters total roughly 4.3 billion -- about two-thirds of each block and roughly 64% of the entire model once embeddings are included. Understanding and optimizing the FFN is therefore critical for both quality and efficiency.
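These fractions are easy to verify with a few lines of arithmetic (using the LLaMA 2 7B shapes; its FFN uses the three-matrix gated design covered in Section 3, with bias-free projections):

```python
d_model, d_ff, n_layers = 4096, 11008, 32

attn_params = 4 * d_model * d_model  # W_Q, W_K, W_V, W_O
ffn_params = 3 * d_model * d_ff      # gated FFN: W_gate, W_up, W_down
total_block = attn_params + ffn_params

print(f"FFN share per block: {ffn_params / total_block:.0%}")         # ~67%
print(f"FFN params, all layers: {n_layers * ffn_params / 1e9:.2f}B")  # ~4.33B
```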
2. Evolution of Activation Functions
The activation function inside the FFN has evolved considerably from the original ReLU, with each generation addressing specific shortcomings.
ReLU (Rectified Linear Unit)
The original Transformer used ReLU. It is simple, fast, and produces sparse activations (many exact zeros in the hidden layer). However, ReLU has well-known issues:
- Dying ReLU problem: If a neuron's pre-activation is consistently negative, it outputs exactly zero and receives zero gradient. The neuron effectively dies and never recovers (see the snippet after this list).
- Non-smooth at zero: The derivative is discontinuous at $x = 0$, which can cause optimization difficulties.
- Not zero-centered: The outputs are always non-negative, which can introduce systematic biases in the residual stream.
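The dying-ReLU problem is easy to demonstrate directly: a negative pre-activation receives exactly zero gradient, so no learning signal reaches the neuron.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 3.0], requires_grad=True)
F.relu(x).sum().backward()
print(x.grad)  # tensor([0., 1.]) -- the negative input receives no gradient
```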
GELU (Gaussian Error Linear Unit)
Hendrycks and Gimpel (2016) proposed GELU as a smooth, probabilistic alternative to ReLU:
$$\text{GELU}(x) = x \cdot \Phi(x)$$

where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution. This can be approximated as:

$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^3\bigr)\right)\right)$$

or more simply:

$$\text{GELU}(x) \approx x \cdot \sigma(1.702x)$$
Intuition: GELU weights each input by its "probability of being important." Inputs with large positive values (high percentile under the Gaussian) pass through nearly unchanged. Inputs near zero are partially suppressed. Inputs with large negative values are strongly suppressed (but not exactly zeroed, unlike ReLU).
GELU became the standard activation for BERT, GPT-2, GPT-3, and RoBERTa.
SiLU / Swish
Ramachandran et al. (2017) discovered SiLU (Sigmoid Linear Unit), also known as Swish, through automated activation function search:

$$\text{SiLU}(x) = x \cdot \sigma(x)$$
Properties:
- Self-gated: The input gates itself through the sigmoid function. Large positive inputs pass through ($\sigma(x) \to 1$), while large negative inputs are suppressed ($\sigma(x) \to 0$).
- Smooth: Infinitely differentiable everywhere, unlike ReLU.
- Non-monotonic: SiLU dips slightly below zero for negative inputs before approaching zero from below. This creates a small "bump" that can help with optimization.
- Unbounded above, bounded below: Similar to ReLU but with smooth behavior near zero.
Activation Function Comparison
```python
import torch
import torch.nn.functional as F
import math

def relu(x: torch.Tensor) -> torch.Tensor:
    return torch.maximum(x, torch.zeros_like(x))

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU using the error function."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    """GELU approximation used in GPT-2."""
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    ))

def silu(x: torch.Tensor) -> torch.Tensor:
    """SiLU/Swish activation. Equivalent to F.silu(x)."""
    return x * torch.sigmoid(x)

# Verify equivalence with PyTorch built-ins
x = torch.randn(100)
assert torch.allclose(silu(x), F.silu(x), atol=1e-6)
assert torch.allclose(gelu_exact(x), F.gelu(x), atol=1e-6)
```
| Activation | Formula | Smooth | Sparse | Non-monotonic | Used By |
|---|---|---|---|---|---|
| ReLU | $\max(0, x)$ | No | Yes | No | Original Transformer |
| GELU | $x \cdot \Phi(x)$ | Yes | No | No | BERT, GPT-2/3 |
| SiLU/Swish | $x \cdot \sigma(x)$ | Yes | No | Yes | LLaMA, Mistral (via SwiGLU) |
3. Gated Linear Units (GLU) and SwiGLU
The Gating Concept
The most significant change to the FFN in modern LLMs is not just the activation function but the introduction of gating. Gated Linear Units (Dauphin et al., 2017) split the FFN into two parallel branches: one that produces a candidate output, and one that produces a gate controlling how much of that output passes through.
The original GLU formulation:

$$\text{GLU}(x) = \sigma(xW) \otimes (xV)$$

where $\otimes$ is element-wise multiplication, $\sigma$ is the sigmoid function, $W$ and $V$ are two separate weight matrices, and $\sigma(xW)$ produces a gate in $(0, 1)$ that controls each dimension of the linear output $xV$.
GLU Variants
Shazeer (2020) systematically explored replacing the sigmoid gate with different activation functions, leading to a family of GLU variants:

$$\text{GLU}_f(x) = f(xW) \otimes (xV)$$

Note that in this formulation, the activation $f$ is applied to the first branch, and the second branch is purely linear (no activation). This is the convention used in LLaMA and most modern implementations. The key variants:
| Variant | Gate Activation | Formula |
|---|---|---|
| GLU | Sigmoid | $\sigma(xW) \otimes (xV)$ |
| ReGLU | ReLU | $\max(0, xW) \otimes (xV)$ |
| GEGLU | GELU | $\text{GELU}(xW) \otimes (xV)$ |
| SwiGLU | SiLU/Swish | $\text{SiLU}(xW) \otimes (xV)$ |
Shazeer found that SwiGLU consistently outperformed other variants across multiple benchmarks, and it has been adopted by virtually every major LLM since.
SwiGLU: The Modern FFN
The complete SwiGLU FFN, as used in LLaMA, consists of three weight matrices:

$$\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{SiLU}(xW_{\text{gate}}) \otimes (xW_{\text{up}})\right) W_{\text{down}}$$
Breaking this down step by step:
- Gate projection: $g = \text{SiLU}(xW_{\text{gate}})$ -- project the input and apply the SiLU activation. This produces a gating signal.
- Up projection: $u = xW_{\text{up}}$ -- project the input through a separate linear transformation (no activation).
- Element-wise gating: $h = g \otimes u$ -- the gate modulates the up-projection, controlling how much information flows through each hidden dimension.
- Down projection: $y = hW_{\text{down}}$ -- project back to the model dimension.
The 2/3 Parameter Adjustment
A critical implementation detail: the gated FFN uses three weight matrices ($W_{\text{gate}}$, $W_{\text{up}}$, $W_{\text{down}}$) instead of the standard FFN's two ($W_1$, $W_2$). To maintain the same total parameter count, the hidden dimension must be reduced by a factor of $2/3$.

Standard FFN parameters: $2 \cdot d_{\text{model}} \cdot d_{\text{ff}}$, where $d_{\text{ff}} = 4d_{\text{model}}$, giving $8d_{\text{model}}^2$ parameters.

SwiGLU FFN parameters: $3 \cdot d_{\text{model}} \cdot d_{\text{ff}}'$, where we want $3 \cdot d_{\text{model}} \cdot d_{\text{ff}}' = 8d_{\text{model}}^2$, so:

$$d_{\text{ff}}' = \frac{8}{3}d_{\text{model}} = \frac{2}{3} \cdot 4d_{\text{model}}$$

In practice, $d_{\text{ff}}'$ is rounded up to the nearest multiple of a convenient number (256 in LLaMA) for hardware efficiency:

$$d_{\text{ff}} = 256 \cdot \left\lceil \frac{8d_{\text{model}}/3}{256} \right\rceil$$

For LLaMA 7B ($d_{\text{model}} = 4096$): $d_{\text{ff}} = 11008$.
SwiGLU Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """SwiGLU Feed-Forward Network as used in LLaMA.

    Implements: FFN(x) = (SiLU(x @ W_gate) * (x @ W_up)) @ W_down

    Uses three weight matrices instead of two, with the hidden
    dimension reduced to 2/3 of the standard FFN to maintain
    the same parameter count.
    """
    def __init__(
        self,
        d_model: int,
        d_ff: int | None = None,
        bias: bool = False
    ):
        super().__init__()
        if d_ff is None:
            # Default: 8/3 * d_model, rounded up to nearest multiple of 256
            d_ff = int(8 * d_model / 3)
            d_ff = 256 * ((d_ff + 255) // 256)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=bias)  # W_gate
        self.up_proj = nn.Linear(d_model, d_ff, bias=bias)    # W_up
        self.down_proj = nn.Linear(d_ff, d_model, bias=bias)  # W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate branch: apply SiLU activation
        gate = F.silu(self.gate_proj(x))
        # Up branch: linear projection (no activation)
        up = self.up_proj(x)
        # Element-wise gating, then down-project
        return self.down_proj(gate * up)
```
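A quick sanity check of the parameter-count claim, using the SwiGLU class above (for $d_{\text{model}} = 4096$, the default rounds to $d_{\text{ff}} = 11008$):

```python
ffn = SwiGLU(d_model=4096)                # d_ff defaults to 11008 after rounding
swiglu_params = sum(p.numel() for p in ffn.parameters())
standard_params = 2 * 4096 * (4 * 4096)   # two-matrix FFN with d_ff = 4 * d_model
print(f"{swiglu_params / 1e6:.1f}M vs {standard_params / 1e6:.1f}M")
# ~135.3M vs ~134.2M -- within about 1% of the standard FFN
```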
Why Does Gating Help?
There are several complementary explanations for why gated FFNs outperform standard FFNs:
- Multiplicative interactions: The element-wise product creates multiplicative interactions between two different projections of the input. This is more expressive than a single projection followed by a point-wise non-linearity.
- Gradient flow: The gating mechanism creates multiple paths for gradient flow. Gradients can flow through the gate branch, the up branch, or both. This is similar in spirit to the gating in LSTMs and GRUs, which was designed specifically to alleviate vanishing gradients.
- Adaptive sparsity: The SiLU gate can suppress certain hidden dimensions for certain inputs, creating an input-dependent sparsity pattern. Unlike ReLU's hard zero, SiLU provides a soft gating signal, allowing the model to smoothly interpolate between "pass" and "block."
- Feature selection: The gate branch can be interpreted as learning which hidden dimensions are relevant for the current input, while the up branch computes the values for those dimensions. This separation of concerns may improve optimization.
Naming Conventions Across Codebases
Different codebases use different names for the same three matrices, which can be confusing:
| Codebase | Gate ($W_{\text{gate}}$) | Up ($W_{\text{up}}$) | Down ($W_{\text{down}}$) |
|---|---|---|---|
| LLaMA official | w1 | w3 | w2 |
| HuggingFace | gate_proj | up_proj | down_proj |
| GPT-NeoX | dense_h_to_4h (fused) | -- | dense_4h_to_h |
The LLaMA naming (w1, w3, w2) can be particularly confusing because w3 is the "up" projection and w2 is the "down" projection, which feels inverted. The HuggingFace naming is more intuitive.
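When porting checkpoints between these conventions, it helps to write the mapping down explicitly. A hypothetical key map (real state-dict keys carry per-layer prefixes such as model.layers.0.mlp. and vary by release):

```python
# Illustrative mapping from LLaMA official FFN names to HuggingFace names.
LLAMA_TO_HF = {
    "w1": "gate_proj",  # gate branch (SiLU applied)
    "w3": "up_proj",    # up branch (no activation)
    "w2": "down_proj",  # output projection
}
```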
4. Mixture of Experts (MoE)
The Dense Model Scaling Wall
The FFN contains the majority of a model's parameters, and scaling dense models means every token must pass through every parameter. For a 70B parameter model, each generated token costs roughly 140 GFLOPs (about 2 FLOPs per parameter), and every one of those 70B weights must be read from memory for every token. Scaling to 1 trillion dense parameters would push this to roughly 2 TFLOPs -- and about 2 TB of weight reads -- per token, far beyond what a single accelerator can deliver in real time.
Mixture of Experts offers an elegant solution: scale the total parameter count without proportionally increasing the per-token computation.
MoE Architecture
In an MoE Transformer, the FFN in some or all layers is replaced with multiple "expert" FFNs plus a routing mechanism. Each token is directed to only a subset of experts, so the computation per token remains manageable while the total model capacity grows with the number of experts.
Formally, the MoE layer output is:

$$y = \sum_{i=1}^{N} G(x)_i \cdot E_i(x)$$

where $E_i$ is the $i$-th expert network (a SwiGLU FFN), $G(x)$ is the gating (routing) function, and $N$ is the total number of experts. In practice, $G(x)$ is sparse -- only the top-$k$ values are non-zero:

$$G(x) = \text{softmax}(\text{TopK}(x W_{\text{router}}))$$

The $\text{TopK}$ function keeps the $k$ largest logits and sets the rest to $-\infty$ before applying softmax.
Efficiency Analysis
The key insight is the ratio between total parameters and active parameters:
| Model | Total Params | Active Params per Token | Experts | Top-K | Efficiency Ratio |
|---|---|---|---|---|---|
| LLaMA 2 70B (dense) | 70B | 70B | 1 | 1 | 1x |
| Mixtral 8x7B | 46.7B | 12.9B | 8 | 2 | 3.6x |
| Mixtral 8x22B | 141B | 39B | 8 | 2 | 3.6x |
| GPT-4 (rumored) | ~1.8T | ~220B | 16 | 2 | ~8x |
Mixtral 8x7B has 46.7B total parameters but only activates 12.9B per token. This means it achieves quality comparable to or better than a 70B dense model (due to more total capacity) while being as fast to run as a 13B dense model (due to fewer active parameters).
The Routing Problem
Routing is the most delicate part of the MoE architecture. The router must learn to distribute tokens to experts in a way that is both effective (each expert specializes) and balanced (no expert is overloaded or underutilized).
Router mechanism: The router is typically a simple linear layer that takes a token representation and produces logits over experts:

$$\ell = x W_{\text{router}} \in \mathbb{R}^{N}$$

The top-$k$ logits are selected, and a softmax is applied over just those values to get the routing weights:

$$G(x)_i = \frac{e^{\ell_i}}{\sum_{j \in \text{TopK}(\ell)} e^{\ell_j}} \;\;\text{for } i \in \text{TopK}(\ell), \qquad G(x)_i = 0 \;\;\text{otherwise}$$
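A detail worth noting: applying softmax over all logits first, then taking the top-k probabilities and renormalizing (as the implementation below does) yields exactly the same weights, because renormalizing a subset of a softmax is equivalent to a softmax over that subset's logits. A quick numerical check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
k = 2

# Order 1: softmax over all experts, then top-k, then renormalize
probs = F.softmax(logits, dim=-1)
w1, idx = torch.topk(probs, k, dim=-1)
w1 = w1 / w1.sum(dim=-1, keepdim=True)

# Order 2: top-k logits, then softmax over just those
top_logits = torch.gather(logits, -1, idx)
w2 = F.softmax(top_logits, dim=-1)

assert torch.allclose(w1, w2, atol=1e-6)
```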
Load Balancing
Without explicit encouragement, the router tends to converge to degenerate solutions where most tokens are routed to a few "popular" experts while others are ignored. This wastes model capacity and can cause training instabilities.
The standard solution is an auxiliary load-balancing loss (Fedus et al., 2022):

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

where:

- $f_i$ is the fraction of tokens routed to expert $i$
- $P_i$ is the average router probability assigned to expert $i$ (before top-k)
- $\alpha$ is a hyperparameter controlling the strength of the balancing signal (typically 0.01)
- $N$ is the number of experts

This loss encourages the products $f_i \cdot P_i$ to be uniform across experts. Since $f$ tracks $P$ in expectation and both sum to 1, the loss is minimized when $f_i = P_i = 1/N$ for all $i$, which corresponds to perfect balance.
Expert Capacity and Token Dropping
In practice, experts have a fixed capacity -- the maximum number of tokens they can process per batch. If more tokens are routed to an expert than its capacity allows, the excess tokens are dropped (their output is set to zero, with only the residual connection surviving). This prevents any single expert from becoming a computational bottleneck.
$$\text{capacity} = \frac{T \cdot k}{N} \cdot c$$

where $T$ is the total number of tokens in the batch, $k$ is the top-k value, $N$ is the number of experts, and $c$ is the capacity factor (typically 1.0-1.5). A capacity factor of 1.0 means exactly the expected number of tokens; higher values provide a buffer for imbalanced routing.
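For example, with a batch of 4096 tokens routed top-2 across 8 experts:

```python
T, k, N, c = 4096, 2, 8, 1.25  # tokens, top-k, experts, capacity factor
capacity = int(T * k / N * c)
print(capacity)  # 1280 tokens per expert; 25% headroom over the expected 1024
```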
MoE Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k routing and load balancing.

    Replaces the standard FFN in a Transformer block. Each expert
    is a SwiGLU FFN, and the router selects the top-k experts per token.

    Note: this reference implementation does not enforce expert capacity
    (no token dropping); capacity_factor is stored for completeness.
    """
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        # Router: linear projection to expert logits
        self.router = nn.Linear(d_model, n_experts, bias=False)
        # Expert FFNs (each is a SwiGLU)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass with top-k routing.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).

        Returns:
            output: MoE output, same shape as input.
            aux_loss: Load-balancing auxiliary loss (scalar).
        """
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (B*T, D)
        num_tokens = x_flat.shape[0]

        # Compute router logits and probabilities
        router_logits = self.router(x_flat)  # (B*T, n_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        top_k_weights, top_k_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )
        # Renormalize weights over selected experts (equivalent to a
        # softmax over just the top-k logits)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Compute auxiliary load-balancing loss
        aux_loss = self._load_balancing_loss(router_probs, top_k_indices)

        # Compute expert outputs
        output = torch.zeros_like(x_flat)
        for i in range(self.n_experts):
            # Find tokens routed to this expert (across any of the top-k slots)
            expert_mask = (top_k_indices == i)    # (num_tokens, top_k)
            token_mask = expert_mask.any(dim=-1)  # (num_tokens,)
            if not token_mask.any():
                continue
            # Get the tokens for this expert
            expert_input = x_flat[token_mask]              # (n_selected, D)
            # Compute expert output
            expert_output = self.experts[i](expert_input)  # (n_selected, D)
            # Get the routing weight for this expert.
            # For each selected token, sum the weights assigned to this expert
            # (a token could select the same expert in multiple top-k slots,
            # though this is rare).
            weights = expert_mask[token_mask] * top_k_weights[token_mask]
            weight = weights.sum(dim=-1, keepdim=True)     # (n_selected, 1)
            output[token_mask] += weight * expert_output

        return output.view(B, T, D), aux_loss

    def _load_balancing_loss(
        self,
        router_probs: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """Compute the auxiliary load-balancing loss.

        Encourages uniform distribution of tokens across experts. The
        coefficient alpha is applied by the caller when adding this
        term to the total training loss.
        """
        num_tokens = router_probs.shape[0]
        # f_i: fraction of tokens routed to each expert.
        # Count how many tokens have expert i in their top-k.
        expert_counts = torch.zeros(
            self.n_experts, device=router_probs.device
        )
        for k in range(self.top_k):
            expert_counts.scatter_add_(
                0, top_k_indices[:, k],
                torch.ones(num_tokens, device=router_probs.device)
            )
        f = expert_counts / num_tokens  # (n_experts,)
        # P_i: average router probability for each expert
        P = router_probs.mean(dim=0)    # (n_experts,)
        # Auxiliary loss: N * sum(f_i * P_i)
        aux_loss = self.n_experts * (f * P).sum()
        return aux_loss
```
Expert Parallelism
MoE introduces a new dimension of parallelism in distributed training. Since different experts are independent modules, they can be placed on different devices:
| Parallelism Strategy | What Is Distributed | Communication Pattern |
|---|---|---|
| Data Parallelism | Same model, different batches | AllReduce on gradients |
| Tensor Parallelism | Weight matrices split across devices | AllReduce on activations |
| Pipeline Parallelism | Different layers on different devices | Point-to-point between stages |
| Expert Parallelism | Different experts on different devices | All-to-All for token routing |
Expert parallelism uses All-to-All communication: tokens are dispatched from their "home" device to whichever device hosts their selected expert, processed there, and then sent back. The communication cost is proportional to the number of tokens that cross device boundaries, which depends on the routing distribution.
In Mixtral, 8 experts can be placed on 8 GPUs with one expert per GPU. Each GPU holds the full attention weights (shared across all tokens) plus one expert FFN. During the FFN computation:
- Each GPU computes router logits for its local tokens.
- All-to-All communication dispatches each token to the GPU(s) holding its selected expert(s).
- Each GPU processes the tokens assigned to its local expert.
- All-to-All communication returns the expert outputs to the originating GPUs.
- Each GPU combines the expert outputs with the routing weights.
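To make the dispatch/combine pattern concrete, here is a single-process simulation of steps 2-5. It is purely illustrative: the two-"device" setup and the stand-in expert are assumptions, and real systems use torch.distributed primitives such as all_to_all_single rather than Python lists.

```python
import torch

n_devices = 2  # pretend each "device" hosts one expert
tokens = torch.randn(6, 4)                     # 6 tokens, d_model = 4
expert_ids = torch.tensor([0, 1, 1, 0, 1, 0])  # router decision (top-1 for simplicity)

# Dispatch: group tokens by destination expert/device (the first all-to-all)
buffers = [tokens[expert_ids == d] for d in range(n_devices)]

# Each "device" processes its local tokens ("+ 1" as a stand-in for the expert FFN)
processed = [buf + 1.0 for buf in buffers]

# Combine: scatter results back to the original token order (the return all-to-all)
output = torch.empty_like(tokens)
for d in range(n_devices):
    output[expert_ids == d] = processed[d]
```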
MoE Design Decisions
Several design choices significantly impact MoE performance:
Number of experts. More experts means more total capacity for the same active parameter count, but also more routing complexity and potential for load imbalance. Mixtral uses 8 experts; some research models have explored 64 or more.
Top-k value. $k = 1$ is most efficient but can be unstable during training. $k = 2$ is the most common choice, providing redundancy in routing while maintaining efficiency. Higher values approach dense computation and reduce the efficiency benefit.
Which layers use MoE. Not every layer needs to be an MoE layer. Some architectures apply MoE only to every other layer, or only to deeper layers. The Mixtral architecture applies MoE to every layer.
Expert granularity. An alternative to having a few large experts is to have many small experts with a higher top-k. DeepSeekMoE explores "fine-grained experts" where each expert is smaller but more experts are selected, potentially leading to more specialized routing.
Shared Expert Variants
DeepSeek-V2 introduced an interesting variant: a shared expert that processes every token, combined with routed experts that are selected per-token:

$$y = E_{\text{shared}}(x) + \sum_{i \in \text{TopK}} G(x)_i \cdot E_i(x)$$
The shared expert captures common patterns that are useful for all tokens, while the routed experts handle specialized knowledge. This can improve the stability of training and reduce the impact of routing errors.
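In code, the change is one extra additive term. A minimal sketch, reusing the SwiGLU and MoELayer classes defined above (DeepSeek's actual implementations differ in details such as fine-grained expert sizes):

```python
import torch
import torch.nn as nn

class SharedExpertMoE(nn.Module):
    """MoE with one always-on shared expert plus routed experts (sketch)."""
    def __init__(self, d_model: int, d_ff: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.shared_expert = SwiGLU(d_model, d_ff)  # processes every token
        self.routed = MoELayer(d_model, d_ff, n_experts, top_k)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        routed_out, aux_loss = self.routed(x)
        return self.shared_expert(x) + routed_out, aux_loss
```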
5. Putting It All Together: The Modern FFN Block
A complete transformer block in a modern dense model (like LLaMA 3) uses:
```python
class ModernTransformerBlock(nn.Module):
    """Complete transformer block following the LLaMA 3 architecture.

    Combines:
    - Pre-LN with RMSNorm (Part 2)
    - Grouped-Query Attention with RoPE (Part 3)
    - SwiGLU FFN (Part 4)
    """
    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        d_ff: int | None = None,
        max_seq_len: int = 8192,
        rope_theta: float = 500000.0
    ):
        super().__init__()
        # RMSNorm for Pre-LN
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        # GQA with RoPE
        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len,
            rope_theta=rope_theta
        )
        # SwiGLU FFN
        self.ffn = SwiGLU(d_model, d_ff)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Pre-RMSNorm + GQA + Residual
        h = self.attn_norm(x)
        h = self.attention(h, start_pos=start_pos, mask=mask)
        x = x + h
        # Pre-RMSNorm + SwiGLU + Residual
        h = self.ffn_norm(x)
        h = self.ffn(h)
        x = x + h
        return x
```
And for an MoE model (like Mixtral), the only change is replacing the SwiGLU with an MoELayer:
```python
class MoETransformerBlock(nn.Module):
    """Transformer block with Mixture of Experts FFN (Mixtral-style)."""
    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        d_ff: int = 14336,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 32768
    ):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len
        )
        # MoE replaces the single FFN
        self.moe = MoELayer(
            d_model=d_model,
            d_ff=d_ff,
            n_experts=n_experts,
            top_k=top_k
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention (same as dense)
        h = self.attn_norm(x)
        h = self.attention(h, mask=mask)
        x = x + h
        # MoE FFN: returns output plus the auxiliary load-balancing loss
        h = self.ffn_norm(x)
        h, aux_loss = self.moe(h)
        x = x + h
        return x, aux_loss
```
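At training time, the auxiliary losses from all MoE layers are summed and added to the language-modeling loss, scaled by the coefficient $\alpha$ from the load-balancing section. A hypothetical training-step fragment (blocks, x, mask, lm_loss, and alpha are placeholders from the surrounding training code):

```python
# Accumulate the load-balancing losses across all MoE layers.
aux_total = torch.zeros((), device=x.device)
for block in blocks:               # each block is a MoETransformerBlock
    x, aux = block(x, mask=mask)
    aux_total = aux_total + aux

loss = lm_loss + alpha * aux_total  # alpha ~ 0.01 (Fedus et al., 2022)
loss.backward()
```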
Summary: Evolution of the FFN
| Era | FFN Type | Activation | Weight Matrices | Hidden Dim | Used By |
|---|---|---|---|---|---|
| 2017 | Standard MLP | ReLU | 2 ($W_1, W_2$) | $4d_{\text{model}}$ | Original Transformer |
| 2018-2020 | Standard MLP | GELU | 2 ($W_1, W_2$) | $4d_{\text{model}}$ | BERT, GPT-2, GPT-3 |
| 2022+ | Gated (SwiGLU) | SiLU/Swish | 3 ($W_{\text{gate}}, W_{\text{up}}, W_{\text{down}}$) | $\frac{8}{3}d_{\text{model}}$ | LLaMA, Mistral, PaLM |
| 2023+ | MoE + SwiGLU | SiLU/Swish | $3N$ + router | $\frac{8}{3}d_{\text{model}}$ per expert | Mixtral, GPT-4 (rumored) |
The progression from simple ReLU MLPs to gated SwiGLU networks to Mixture of Experts represents a steady refinement: each generation offers more expressive feature transformations, better gradient flow, or more efficient scaling. The SwiGLU FFN, in particular, has become as fundamental to modern LLM design as the attention mechanism itself.
In the next post, we will explore Part 5: Training Improvements -- the optimization techniques that make training these architectures feasible at scale, including AdamW, learning rate scheduling, mixed-precision training, and gradient checkpointing.
References
- Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
- Hendrycks, D. and Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)." arXiv:1606.08415.
- Ramachandran, P., Zoph, B., Le, Q. V. (2017). "Searching for Activation Functions." arXiv:1710.05941.
- Dauphin, Y. N., Fan, A., Auli, M., Grangier, D. (2017). "Language Modeling with Gated Convolutional Networks." ICML 2017.
- Shazeer, N. (2020). "GLU Variants Improve Transformer." arXiv:2002.05202.
- Geva, M., Schuster, R., Berant, J., Levy, O. (2021). "Transformer Feed-Forward Layers Are Key-Value Memories." EMNLP 2021.
- Fedus, W., Zoph, B., Shazeer, N. (2022). "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity." JMLR 2022.
- Jiang, A. Q., Sablayrolles, A., Roux, A., et al. (2024). "Mixtral of Experts." arXiv:2401.04088.
- Touvron, H., Lavril, T., Izacard, G., et al. (2023). "LLaMA: Open and Efficient Foundation Language Models." arXiv:2302.13971.
- Dai, D., Deng, C., Zhao, C., et al. (2024). "DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models." arXiv:2401.06066.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.