In the previous posts, we examined the macro architecture (Part 2) and the attention mechanism (Part 3). Now we turn to the other half of every Transformer block: the Feed-Forward Network (FFN). Attention decides which tokens to mix; the FFN decides what to do with the mixed representations. Despite containing roughly two-thirds of the parameters in each Transformer block, the FFN receives far less attention (no pun intended) in most discussions. That is a mistake -- the evolution from simple ReLU MLPs to gated architectures and Mixture of Experts represents some of the most impactful changes in modern LLM design.

1. The Role of the FFN in Transformers

What Does the FFN Actually Do?

In the original Transformer, the FFN is a simple two-layer MLP applied independently to each token position:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

where $W_1 \in \mathbb{R}^{d \times d_{ff}}$ projects from the model dimension $d$ to a larger hidden dimension $d_{ff}$ (typically $4d$), and $W_2 \in \mathbb{R}^{d_{ff} \times d}$ projects back.

The FFN serves several distinct purposes:

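  1. Non-linear feature transformation: Once the attention weights are computed, the attention output is a weighted sum of value vectors, which is linear in the values; its only non-linearity is the softmax that produces the weights. The FFN applies an element-wise non-linearity to every token's features, which the network needs in order to learn complex functions.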

  2. Dimensional expansion and compression: The expansion to $d_{ff} = 4d$ creates a higher-dimensional space where more complex feature interactions can occur, before compressing back to $d$.

  3. Per-token processing: Unlike attention, which mixes information across positions, the FFN operates independently on each token. It transforms what each token represents, rather than where it attends.

  4. Knowledge storage: This is perhaps the most fascinating role. Geva et al. (2021) showed that FFN layers can be read as key-value memories: each column of $W_1$ acts as a key (a pattern detector) and the corresponding row of $W_2$ acts as a value (an associated output). When an input matches a key (high activation in that hidden unit), the value is scaled by the activation and added to the residual stream. This suggests that much of the factual knowledge in a language model is stored in the FFN weights.
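Items 3 and 4 can be checked numerically. The sketch below is a minimal PyTorch version of the original ReLU FFN; the names StandardFFN and ffn_as_memories are illustrative, not from any library. It recomputes the output as an explicit sum over hidden units, each unit's value vector weighted by how strongly its key fires, and runs a single position on its own to show that the FFN never looks at other tokens.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class StandardFFN(nn.Module):
    """Original Transformer FFN: ReLU(x W1 + b1) W2 + b2 (illustrative)."""

    def __init__(self, d_model: int, d_ff: Optional[int] = None):
        super().__init__()
        d_ff = d_ff or 4 * d_model          # the usual 4x expansion
        self.w1 = nn.Linear(d_model, d_ff)  # expand:   d -> d_ff
        self.w2 = nn.Linear(d_ff, d_model)  # compress: d_ff -> d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.relu(self.w1(x)))


def ffn_as_memories(ffn: StandardFFN, x: torch.Tensor) -> torch.Tensor:
    """Same output, written as a sum over d_ff key-value 'memories'."""
    keys = ffn.w1.weight.T    # (d, d_ff): keys[:, j] is the key of hidden unit j
    values = ffn.w2.weight.T  # (d_ff, d): values[j] is the value of hidden unit j
    # How strongly each key matches the input (one coefficient per hidden unit).
    coeffs = F.relu(x @ keys + ffn.w1.bias)
    # Each value vector is scaled by its coefficient and the results are summed.
    return coeffs @ values + ffn.w2.bias


torch.manual_seed(0)
ffn = StandardFFN(d_model=16)
x = torch.randn(2, 5, 16)   # (batch, seq_len, d_model)
y = ffn(x)

# Per-token processing: position 3 on its own gives the same result as in context.
y_pos3 = ffn(x[:, 3])
```

Because nn.Linear stores its weight as (out_features, in_features), the transposes recover the $xW_1$ convention used in the equations.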

Parameter Distribution

The FFN dominates the parameter count of every Transformer block:

| Component | Parameters per Block | Fraction |
|---|---|---|
| Attention: $W_Q, W_K, W_V, W_O$ | $4 \times d^2$ | ~33% |
| FFN: $W_1, W_2$ | $2 \times d \times d_{ff} = 8d^2$ (when $d_{ff} = 4d$) | ~67% |

For a model like LLaMA 2 7B ($d = 4096$, $d_{ff} = 11008$, 32 layers), which uses the three-matrix SwiGLU FFN described in Section 3, the FFN parameters total $3 \times 4096 \times 11008 \times 32 \approx 4.3$ billion, about 64% of the model's 6.7 billion parameters. Understanding and optimizing the FFN is therefore critical for both quality and efficiency.
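The LLaMA 2 7B figures can be reproduced with a few lines of arithmetic. This sketch assumes LLaMA 2 7B's published shape (32,000-token vocabulary, untied input and output embeddings, full multi-head attention, no biases) and ignores the small RMSNorm weights.

```python
# LLaMA 2 7B shape (assumed): no biases, untied embeddings, full MHA.
d_model, d_ff, n_layers, vocab = 4096, 11008, 32, 32000

ffn_per_layer = 3 * d_model * d_ff       # W_gate, W_up, W_down (SwiGLU)
attn_per_layer = 4 * d_model * d_model   # W_Q, W_K, W_V, W_O
embeddings = 2 * vocab * d_model         # token embedding + output head

ffn_total = n_layers * ffn_per_layer
total = n_layers * (ffn_per_layer + attn_per_layer) + embeddings  # RMSNorm weights ignored

print(f"FFN params:   {ffn_total / 1e9:.2f}B")   # 4.33B
print(f"Total params: {total / 1e9:.2f}B")       # 6.74B
print(f"FFN share:    {ffn_total / total:.1%}")  # 64.2%
```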

2. Evolution of Activation Functions

The activation function inside the FFN has evolved considerably from the original ReLU, with each generation addressing specific shortcomings.

ReLU (Rectified Linear Unit)

ReLU(x)=max(0,x)\text{ReLU}(x) = \max(0, x)

The original Transformer used ReLU. It is simple, fast, and produces sparse activations (many exact zeros in the hidden layer). However, ReLU has well-known issues:

  • Dying ReLU problem: If a neuron's pre-activation is negative for every input, it outputs exactly zero and receives zero gradient, so its incoming weights stop updating. The neuron effectively dies and usually does not recover.
  • Non-smooth at zero: The derivative is discontinuous at $x = 0$, which can cause optimization difficulties.
  • Not zero-centered: The outputs are always non-negative, which can introduce systematic biases in the residual stream.
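A tiny autograd check makes the dying-ReLU point concrete. The pre-activation values are made up for illustration; for comparison, the same values are also passed through SiLU, which is covered below.

```python
import torch
import torch.nn.functional as F

# Made-up pre-activations of one hidden neuron across a batch: all negative.
pre = torch.tensor([-2.0, -0.5, -1.0, -0.1], requires_grad=True)
out = F.relu(pre)
out.sum().backward()
relu_grad = pre.grad.clone()   # nothing flows back toward the neuron's weights

# The same inputs through SiLU still yield a non-zero gradient.
pre2 = pre.detach().clone().requires_grad_(True)
F.silu(pre2).sum().backward()
silu_grad = pre2.grad.clone()

print(out.tolist())        # all zeros
print(relu_grad.tolist())  # all zeros
print(silu_grad.tolist())  # small but non-zero values
```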

GELU (Gaussian Error Linear Unit)

Hendrycks and Gimpel (2016) proposed GELU as a smooth, probabilistic alternative to ReLU:

$$\text{GELU}(x) = x \cdot \Phi(x)$$

where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution. This can be approximated as:

$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)$$

or more simply:

$$\text{GELU}(x) \approx x \cdot \sigma(1.702x)$$

Intuition: GELU weights each input by its "probability of being important." Inputs with large positive values (high percentile under the Gaussian) pass through nearly unchanged. Inputs near zero are partially suppressed. Inputs with large negative values are strongly suppressed (but not exactly zeroed, unlike ReLU).
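To see how close the two approximations are, this sketch evaluates the exact form and both approximations on a dense grid in double precision and reports the worst-case absolute error. The grid range [-6, 6] is an arbitrary choice. Expect the tanh form to be far more accurate than the sigmoid form.

```python
import math

import torch

x = torch.linspace(-6.0, 6.0, 120_001, dtype=torch.float64)

exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))   # x * Phi(x)
tanh_approx = 0.5 * x * (1.0 + torch.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
sigmoid_approx = x * torch.sigmoid(1.702 * x)

err_tanh = (tanh_approx - exact).abs().max().item()
err_sigmoid = (sigmoid_approx - exact).abs().max().item()
print(f"max |error|, tanh form:    {err_tanh:.1e}")
print(f"max |error|, sigmoid form: {err_sigmoid:.1e}")
```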

GELU became the standard activation for BERT, GPT-2, GPT-3, and RoBERTa.

SiLU / Swish

Ramachandran et al. (2017) discovered SiLU (Sigmoid Linear Unit), also known as Swish, through automated activation function search:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

Properties:

  • Self-gated: The input gates itself through the sigmoid function. Large positive inputs pass through ($\sigma(x) \to 1$), while large negative inputs are suppressed ($\sigma(x) \to 0$).
  • Smooth: Infinitely differentiable everywhere, unlike ReLU.
  • Non-monotonic: SiLU dips below zero for negative inputs, reaching a minimum of about $-0.28$ near $x \approx -1.28$, before approaching zero from below. This small "bump" can help with optimization.
  • Unbounded above, bounded below: Similar to ReLU but with smooth behavior near zero.
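The dip below zero can be located numerically. This sketch scans a grid of negative inputs with PyTorch's built-in F.silu and F.gelu. SiLU bottoms out at about -0.28 near x = -1.28, and GELU turns out to have a similar but shallower dip (about -0.17 near x = -0.75), so it is non-monotonic as well.

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5.0, 0.0, 50_001, dtype=torch.float64)

minima = {}
for name, fn in [("SiLU", F.silu), ("GELU", F.gelu)]:
    y = fn(x)
    i = torch.argmin(y)                       # grid point with the lowest output
    minima[name] = (x[i].item(), y[i].item())
    print(f"{name}: minimum {y[i].item():.3f} at x = {x[i].item():.3f}")
```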

Comparison of activation functions: ReLU (piecewise linear, zero for negative inputs), GELU (smooth approximation of ReLU with probabilistic gating), and SiLU/Swish (self-gated, slightly non-monotonic near zero).

Activation Function Comparison

import torch
import torch.nn.functional as F
import math


def relu(x: torch.Tensor) -> torch.Tensor:
    return torch.maximum(x, torch.zeros_like(x))


def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU using the error function."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    """GELU approximation used in GPT-2."""
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
    ))


def silu(x: torch.Tensor) -> torch.Tensor:
    """SiLU/Swish activation. Equivalent to F.silu(x)."""
    return x * torch.sigmoid(x)


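# Verify equivalence with PyTorch built-ins
x = torch.randn(100)
assert torch.allclose(relu(x), F.relu(x))
assert torch.allclose(silu(x), F.silu(x), atol=1e-6)
assert torch.allclose(gelu_exact(x), F.gelu(x), atol=1e-6)
# approximate="tanh" requires PyTorch 1.12 or later
assert torch.allclose(gelu_tanh_approx(x), F.gelu(x, approximate="tanh"), atol=1e-6)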

| Activation | Formula | Smooth | Sparse | Non-monotonic | Used By |
|---|---|---|---|---|---|
| ReLU | $\max(0, x)$ | No | Yes | No | Original Transformer |
| GELU | $x\Phi(x)$ | Yes | No | Yes | BERT, GPT-2/3 |
| SiLU/Swish | $x\sigma(x)$ | Yes | No | Yes | LLaMA, Mistral (via SwiGLU) |

3. Gated Linear Units (GLU) and SwiGLU

The Gating Concept

The most significant change to the FFN in modern LLMs is not just the activation function but the introduction of gating. Gated Linear Units (Dauphin et al., 2017) compute two parallel projections of the same input: one produces a candidate output, and the other produces a gate controlling how much of that output passes through.

The original GLU formulation:

$$\text{GLU}(x) = (xW_1 + b_1) \otimes \sigma(xV + c)$$

where $\otimes$ is element-wise multiplication, $\sigma$ is the sigmoid function, $W_1$ and $V$ are two separate weight matrices, and $\sigma(xV + c)$ produces a gate in $[0, 1]$ that controls each dimension of the linear output $xW_1 + b_1$.

GLU Variants

Shazeer (2020) systematically explored replacing the sigmoid gate with different activation functions, leading to a family of GLU variants:

$$\text{GLU Variant}(x) = \text{Activation}(xW_1) \otimes (xV)$$

Note that in this formulation, the activation is applied to the first branch, and the second branch is purely linear (no activation). This is the convention used in LLaMA and most modern implementations. The key variants:

| Variant | Gate Activation | Formula |
|---|---|---|
| GLU | Sigmoid $\sigma$ | $\sigma(xW_1) \otimes (xV)$ |
| ReGLU | ReLU | $\text{ReLU}(xW_1) \otimes (xV)$ |
| GEGLU | GELU | $\text{GELU}(xW_1) \otimes (xV)$ |
| SwiGLU | SiLU/Swish | $\text{SiLU}(xW_1) \otimes (xV)$ |

Shazeer found that the gated variants outperformed the non-gated ReLU and GELU baselines, with GEGLU and SwiGLU giving the best results, and SwiGLU has since been adopted by many major LLMs, including PaLM, LLaMA, and Mistral.
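The whole family in the table can share one module that takes the activation as a parameter. This is a sketch in the spirit of Shazeer's formulation, with biases omitted and an output projection $W_2$ added to form a complete FFN; the class name GLUVariantFFN is hypothetical.

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class GLUVariantFFN(nn.Module):
    """FFN of the form (act(x W1) * (x V)) W2, without biases.

    act = sigmoid -> GLU, relu -> ReGLU, gelu -> GEGLU, silu -> SwiGLU.
    """

    def __init__(self, d_model: int, d_ff: int,
                 act: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # activated branch
        self.v = nn.Linear(d_model, d_ff, bias=False)   # purely linear branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model
        self.act = act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)) * self.v(x))


VARIANTS = {
    "GLU": torch.sigmoid,
    "ReGLU": F.relu,
    "GEGLU": F.gelu,
    "SwiGLU": F.silu,
}

torch.manual_seed(0)
x = torch.randn(4, 10, 32)   # (batch, seq_len, d_model)
outputs = {name: GLUVariantFFN(32, 64, act)(x) for name, act in VARIANTS.items()}
```

Only the activation differs between variants; the three-matrix structure, and therefore the parameter count, is identical.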

SwiGLU: The Modern FFN

The complete SwiGLU FFN, as used in LLaMA, consists of three weight matrices:

$$\text{FFN}_{\text{SwiGLU}}(x) = \big(\text{SiLU}(xW_{\text{gate}}) \otimes (xW_{\text{up}})\big) W_{\text{down}}$$

Breaking this down step by step:

  1. Gate projection: $g = \text{SiLU}(xW_{\text{gate}})$ -- project the input and apply the SiLU activation. This produces a gating signal.
  2. Up projection: $u = xW_{\text{up}}$ -- project the input through a separate linear transformation (no activation).
  3. Element-wise gating: $h = g \otimes u$ -- the gate modulates the up-projection, controlling how much information flows through each hidden dimension.
  4. Down projection: $y = hW_{\text{down}}$ -- project back to the model dimension.

The 2/3 Parameter Adjustment

A critical implementation detail: the gated FFN uses three weight matrices ($W_{\text{gate}}, W_{\text{up}}, W_{\text{down}}$) instead of the standard FFN's two ($W_1, W_2$). To keep the total parameter count the same, the hidden dimension is shrunk to $\frac{2}{3}$ of its standard value.

Standard FFN parameters: $2 \times d \times d_{ff}$, where $d_{ff} = 4d$, giving $8d^2$ parameters.

SwiGLU FFN parameters: $3 \times d \times d_{ff}'$, where we want $3 \times d \times d_{ff}' = 8d^2$, so:

$$d_{ff}' = \frac{8d}{3} = \frac{2}{3} \times 4d$$

In practice, $d_{ff}'$ is rounded up to the next multiple of a hardware-friendly number (256 in LLaMA):

$$d_{ff} = 256 \times \left\lceil \frac{8d/3}{256} \right\rceil$$

For LLaMA 7B ($d = 4096$): $d_{ff} = 256 \times \lceil 10922.67 / 256 \rceil = 256 \times 43 = 11008$.
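The sizing rule fits in a small function. This sketch mirrors the hidden-dimension logic of Meta's reference LLaMA code as I understand it; the optional ffn_dim_multiplier is an extra widening knob used by later LLaMA releases, and the settings in the last line (multiple_of=1024, ffn_dim_multiplier=1.3) are my reading of the LLaMA 3 8B configuration, so check them against the released parameters before relying on them.

```python
from typing import Optional


def llama_ffn_hidden_dim(d_model: int, multiple_of: int = 256,
                         ffn_dim_multiplier: Optional[float] = None) -> int:
    """SwiGLU hidden size: 2/3 of 4*d_model, optionally widened, rounded up."""
    hidden = int(2 * (4 * d_model) / 3)        # the 2/3 rule: 8d/3, truncated
    if ffn_dim_multiplier is not None:         # optional extra widening
        hidden = int(ffn_dim_multiplier * hidden)
    # Round UP to the next multiple of `multiple_of` (ceiling division).
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


d = 4096
d_ff = llama_ffn_hidden_dim(d)
print(d_ff)                              # 11008
print(3 * d * d_ff / (8 * d * d))        # 1.0078125: rounding up adds ~0.8%
print(llama_ffn_hidden_dim(d, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336
```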

SwiGLU Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """SwiGLU Feed-Forward Network as used in LLaMA.

    Implements: FFN(x) = (SiLU(x @ W_gate) * (x @ W_up)) @ W_down

    Uses three weight matrices instead of two, with the hidden
    dimension reduced to 2/3 of the standard FFN to maintain
    the same parameter count.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int | None = None,
        bias: bool = False
    ):
        super().__init__()
        if d_ff is None:
            # Default: 8/3 * d_model, rounded up to nearest multiple of 256
            d_ff = int(8 * d_model / 3)
            d_ff = 256 * ((d_ff + 255) // 256)

        self.gate_proj = nn.Linear(d_model, d_ff, bias=bias)  # W_gate
        self.up_proj = nn.Linear(d_model, d_ff, bias=bias)    # W_up
        self.down_proj = nn.Linear(d_ff, d_model, bias=bias)  # W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate branch: apply SiLU activation
        gate = F.silu(self.gate_proj(x))
        # Up branch: linear projection (no activation)
        up = self.up_proj(x)
        # Element-wise gating, then down-project
        return self.down_proj(gate * up)

Why Does Gating Help?

There are several complementary explanations for why gated FFNs outperform standard FFNs:

  1. Multiplicative interactions: The element-wise product gug \otimes u creates multiplicative interactions between two different projections of the input. This is more expressive than a single projection followed by a point-wise non-linearity.

  2. Gradient flow: The gating mechanism creates multiple paths for gradient flow. Gradients can flow through the gate branch, the up branch, or both. This is similar in spirit to the gating in LSTMs and GRUs, which was designed specifically to alleviate vanishing gradients.

  3. Adaptive sparsity: The SiLU gate can suppress certain hidden dimensions for certain inputs, creating an input-dependent sparsity pattern. Unlike ReLU's hard zero, SiLU provides a soft gating signal, allowing the model to smoothly interpolate between "pass" and "block."

  4. Feature selection: The gate branch can be interpreted as learning which hidden dimensions are relevant for the current input, while the up branch computes the values for those dimensions. This separation of concerns may improve optimization.
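The multiplicative-interaction point can be illustrated with a single hidden unit. The projection vectors here are hand-picked for illustration. With an identity in place of the gate activation (the "bilinear" variant in Shazeer's paper), one gated unit computes the product $x_1 x_2$ exactly, which a single ReLU unit, being piecewise linear, cannot do; the SwiGLU unit gives a smoothly gated version of the same product.

```python
import torch
import torch.nn.functional as F

# One hidden unit of a gated FFN computes act(x . w) * (x . v).
w = torch.tensor([1.0, 0.0])   # gate projection: picks out x1 (hand-picked)
v = torch.tensor([0.0, 1.0])   # up projection:   picks out x2 (hand-picked)

x = torch.randn(1000, 2)

# Identity "activation" (the bilinear variant): exactly the product x1 * x2.
bilinear_unit = (x @ w) * (x @ v)
# SwiGLU unit: x2 scaled by a smooth, input-dependent gate SiLU(x1).
swiglu_unit = F.silu(x @ w) * (x @ v)

print(torch.allclose(bilinear_unit, x[:, 0] * x[:, 1]))  # True
```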

Naming Conventions Across Codebases

Different codebases use different names for the same three matrices, which can be confusing:

| Codebase | Gate ($W_{\text{gate}}$) | Up ($W_{\text{up}}$) | Down ($W_{\text{down}}$) |
|---|---|---|---|
| LLaMA official | w1 | w3 | w2 |
| HuggingFace | gate_proj | up_proj | down_proj |
| GPT-NeoX | dense_h_to_4h (fused) | -- | dense_4h_to_h |

The LLaMA naming (w1, w3, w2) can be particularly confusing because w3 is the "up" projection and w2 is the "down" projection, which feels inverted. The HuggingFace naming is more intuitive.
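When moving checkpoints between codebases, the renaming is mechanical. This hypothetical helper maps LLaMA-official FFN weight names to HuggingFace names. The key layouts it matches (layers.N.feed_forward.wK.weight on one side, model.layers.N.mlp.*.weight on the other) are my assumption about the two formats and should be verified against real state dicts; it handles only the FFN weights.

```python
import re
from typing import Optional

# Hypothetical helper. The key layouts are assumptions about the two
# checkpoint formats; verify them against your own state dicts.
LLAMA_TO_HF = {"w1": "gate_proj", "w3": "up_proj", "w2": "down_proj"}
FFN_KEY = re.compile(r"^layers\.(\d+)\.feed_forward\.(w[123])\.weight$")


def rename_ffn_key(key: str) -> Optional[str]:
    """Map a LLaMA-official FFN weight name to the HuggingFace name."""
    match = FFN_KEY.match(key)
    if match is None:
        return None                      # not an FFN weight; leave to other rules
    layer, name = match.groups()
    return f"model.layers.{layer}.mlp.{LLAMA_TO_HF[name]}.weight"


print(rename_ffn_key("layers.0.feed_forward.w3.weight"))
# model.layers.0.mlp.up_proj.weight
```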

4. Mixture of Experts (MoE)

The Dense Model Scaling Wall

The FFN contains the majority of a model's parameters, and in a dense model every token must pass through every parameter. A forward pass costs roughly 2 FLOPs per parameter per token, so a 70B-parameter model needs about 140 GFLOPs per generated token, or about 4.2 TFLOP/s to generate 30 tokens/second for a single sequence. A 1-trillion-parameter dense model would need about 2 TFLOPs per token, and its weights alone would occupy about 2 TB in 16-bit precision, far more than a single accelerator's memory, all of which must be read for every generated token.
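These figures come from the common rule of thumb of about 2 FLOPs per parameter per generated token (one multiply and one add per weight), ignoring attention-score computation and the KV cache; the 16-bit weight size is also an assumption. A minimal sketch:

```python
from typing import Dict


def per_token_cost(n_params: float, tokens_per_s: float = 30.0,
                   bytes_per_param: int = 2) -> Dict[str, float]:
    """Rough dense-model cost: ~2 FLOPs per parameter per generated token."""
    flops_per_token = 2 * n_params          # one multiply + one add per weight
    return {
        "GFLOPs per token": flops_per_token / 1e9,
        "TFLOP/s at target rate": flops_per_token * tokens_per_s / 1e12,
        "weight memory (GB)": n_params * bytes_per_param / 1e9,  # 16-bit weights
    }


for n_params in (70e9, 1e12):
    print(f"{n_params / 1e9:,.0f}B params:", per_token_cost(n_params))
```

For 70B parameters this gives 140 GFLOPs per token, 4.2 TFLOP/s at 30 tokens/s, and 140 GB of weights; for 1T parameters, 2,000 GFLOPs per token and 2,000 GB of weights. Both compute and memory traffic grow linearly with the number of parameters each token touches, which is the quantity MoE decouples from the total parameter count.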

Mixture of Experts offers an elegant solution: scale the total parameter count without proportionally increasing the per-token computation.

MoE Architecture

In an MoE Transformer, the FFN in some or all layers is replaced with multiple "expert" FFNs plus a routing mechanism. Each token is directed to only a subset of experts, so the computation per token remains manageable while the total model capacity grows with the number of experts.

Mixture of Experts architecture. A router network receives the input token representation and produces a probability distribution over experts. The top-K experts (typically K=1 or K=2) are selected, their outputs are computed, and the final output is a weighted sum based on the router probabilities.

Formally, the MoE layer output is:

$$\text{MoE}(x) = \sum_{i=1}^{N} G(x)_i \cdot E_i(x)$$

where $E_i$ is the $i$-th expert network (a SwiGLU FFN), $G(x)$ is the gating (routing) function, and $N$ is the total number of experts. In practice, $G(x)$ is sparse -- only the top-$k$ values are non-zero:

$$G(x) = \text{Softmax}\big(\text{TopK}(x \cdot W_g, k)\big)$$

The $\text{TopK}$ function keeps the $k$ largest logits and sets the rest to $-\infty$ before applying softmax.

Efficiency Analysis

The key insight is the ratio between total parameters and active parameters:

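| Model | Total Params | Active Params per Token | Experts | Top-K | Efficiency Ratio (Total / Active) |
|---|---|---|---|---|---|
| LLaMA 2 70B (dense) | 70B | 70B | 1 | 1 | 1x |
| Mixtral 8x7B | 46.7B | 12.9B | 8 | 2 | 3.6x |
| Mixtral 8x22B | 141B | 39B | 8 | 2 | 3.6x |
| GPT-4 (rumored) | ~1.8T | ~220B | 16 | 2 | ~8x |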

Mixtral 8x7B has 46.7B total parameters but activates only 12.9B per token. Jiang et al. (2024) report that it matches or outperforms LLaMA 2 70B on the benchmarks they evaluate, while its per-token compute is comparable to that of a 13B dense model. The saving is in compute, not memory: all 46.7B parameters must still be loaded.
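The 46.7B and 12.9B figures follow directly from Mixtral's architecture. This sketch uses the hyperparameters listed in the Mixtral paper (d = 4096, d_ff = 14336, 32 layers, 32 query heads and 8 KV heads of size 128, 32,000-token vocabulary) and assumes untied embeddings, ignoring norm weights.

```python
# Mixtral 8x7B hyperparameters (from the Mixtral paper); norms ignored.
d, d_ff, n_layers, vocab = 4096, 14336, 32, 32000
n_heads, n_kv_heads, head_dim = 32, 8, 128
n_experts, top_k = 8, 2

expert = 3 * d * d_ff                                    # one SwiGLU expert
attention = (2 * d * n_heads * head_dim                  # W_Q and W_O
             + 2 * d * n_kv_heads * head_dim)            # W_K and W_V (GQA)
router = d * n_experts
embeddings = 2 * vocab * d                               # untied input/output

shared = n_layers * (attention + router) + embeddings    # used by every token
total = shared + n_layers * n_experts * expert
active = shared + n_layers * top_k * expert

print(f"total  {total / 1e9:.1f}B")     # 46.7B
print(f"active {active / 1e9:.1f}B")    # 12.9B
print(f"ratio  {total / active:.1f}x")  # 3.6x
```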

The Routing Problem

Routing is the most delicate part of the MoE architecture. The router must learn to distribute tokens to experts in a way that is both effective (each expert specializes) and balanced (no expert is overloaded or underutilized).

Router mechanism: The router is typically a simple linear layer that takes a token representation and produces logits over experts:

$$\text{router\_logits} = x \cdot W_g, \quad W_g \in \mathbb{R}^{d \times N}$$

The top-$k$ logits are selected, and a softmax is applied over just those $k$ values to get the routing weights, where $r_i$ denotes the router logit for expert $i$:

$$w_i = \frac{e^{r_i}}{\sum_{j \in \text{TopK}} e^{r_j}}, \quad i \in \text{TopK}$$
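The formula above and the MoELayer code below take different routes to the same weights: the code applies softmax over all experts, keeps the top-k probabilities, and renormalizes them. Because the full-softmax normalizer cancels in the renormalization, the two agree. A quick check with random logits (the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 8)   # (tokens, experts): router logits r = x W_g
k = 2

# (a) Formula above: keep the top-k logits, set the rest to -inf, softmax.
top_vals, top_idx = logits.topk(k, dim=-1)
masked = torch.full_like(logits, float("-inf")).scatter(-1, top_idx, top_vals)
gates_a = F.softmax(masked, dim=-1)          # exact zeros outside the top-k

# (b) What the MoELayer code does: softmax over all experts, take the top-k
#     probabilities, and renormalize them to sum to 1.
probs = F.softmax(logits, dim=-1)
top_p, top_idx_b = probs.topk(k, dim=-1)
gates_b = torch.zeros_like(probs).scatter(
    -1, top_idx_b, top_p / top_p.sum(dim=-1, keepdim=True))
```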

Load Balancing

Without explicit encouragement, the router tends to converge to degenerate solutions where most tokens are routed to a few "popular" experts while others are ignored. This wastes model capacity and can cause training instabilities.

The standard solution is an auxiliary load-balancing loss (Fedus et al., 2022):

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \sum_{i=1}^{N} f_i \cdot P_i$$

where:

  • $f_i = \frac{1}{T}\sum_{t=1}^{T} \mathbb{1}[\text{token } t \text{ routed to expert } i]$ is the fraction of tokens routed to expert $i$
  • $P_i = \frac{1}{T}\sum_{t=1}^{T} p_i(x_t)$ is the average router probability assigned to expert $i$, where $p(x) = \text{Softmax}(x \cdot W_g)$ is taken over all experts (before top-k)
  • $\alpha$ is a hyperparameter controlling the strength of the balancing signal (typically 0.01)
  • $N$ is the number of experts

This loss pushes routing toward balance. For top-1 routing, if the hard assignments track the router probabilities ($f_i \approx P_i$), the loss becomes $\alpha N \sum_i P_i^2$, which by the Cauchy-Schwarz inequality is at least $\alpha$, with equality exactly when $f_i = P_i = 1/N$ for all $i$, i.e. perfect balance. Because $f_i$ comes from a hard count, it is not differentiable; the gradient flows only through $P_i$ and is proportional to $f_i$, so it pushes router probability away from the experts that currently receive the most tokens.
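Plugging extreme routings into the loss shows its range. This sketch treats $f$ and $P$ as given vectors (aux_loss is a hypothetical helper, with $\alpha = 0.01$) and also checks that the gradient with respect to $P$ is $\alpha N f$.

```python
import torch


def aux_loss(f: torch.Tensor, P: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """alpha * N * sum_i f_i * P_i, for N experts (hypothetical helper)."""
    return alpha * f.numel() * (f * P).sum()


N = 8
uniform = torch.full((N,), 1.0 / N)
collapsed = torch.zeros(N)
collapsed[0] = 1.0                     # all tokens and all router mass on expert 0

print(aux_loss(uniform, uniform).item())      # ~0.01 = alpha, the minimum
print(aux_loss(collapsed, collapsed).item())  # ~0.08 = alpha * N

# Gradient w.r.t. P is alpha * N * f: the busiest expert is pushed down hardest.
f_skewed = torch.tensor([0.5, 0.3, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0])
P = torch.full((N,), 1.0 / N, requires_grad=True)
aux_loss(f_skewed, P).backward()
print(P.grad)
```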

Expert Capacity and Token Dropping

In practice, experts have a fixed capacity -- the maximum number of tokens they can process per batch. If more tokens are routed to an expert than its capacity allows, the excess tokens are dropped (their output is set to zero, with only the residual connection surviving). This prevents any single expert from becoming a computational bottleneck.

$$\text{capacity} = \left\lceil \frac{T \cdot k}{N} \cdot C_f \right\rceil$$

where $T$ is the total number of tokens, $k$ is top-k, $N$ is the number of experts, and $C_f$ is the capacity factor (typically 1.0-1.5). A capacity factor of 1.0 means exactly the expected number of tokens per expert under perfectly balanced routing; higher values provide a buffer for imbalanced routing.
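The capacity formula and token dropping can be sketched for top-1 routing. This assumes a simple policy in which tokens claim expert slots in sequence order and any token beyond capacity is dropped; production systems differ in how they prioritize tokens, and top-k > 1 needs per-slot bookkeeping. The function names are illustrative.

```python
import math

import torch
import torch.nn.functional as F


def expert_capacity(num_tokens: int, top_k: int, n_experts: int,
                    capacity_factor: float) -> int:
    return math.ceil(num_tokens * top_k / n_experts * capacity_factor)


def keep_mask(expert_idx: torch.Tensor, n_experts: int, capacity: int) -> torch.Tensor:
    """Top-1 routing: True for tokens that fit within their expert's capacity."""
    one_hot = F.one_hot(expert_idx, n_experts)          # (T, N)
    # Running count per expert gives each token its 1-based slot in that expert.
    slot = (torch.cumsum(one_hot, dim=0) * one_hot).sum(dim=-1)
    return slot <= capacity                             # (T,) bool


idx = torch.tensor([0, 0, 0, 1, 2, 0, 3, 1])            # expert chosen per token
for cf in (1.0, 1.25):
    cap = expert_capacity(len(idx), top_k=1, n_experts=4, capacity_factor=cf)
    print(cf, cap, keep_mask(idx, 4, cap).tolist())
```

With 8 tokens, 4 experts, and $C_f = 1.0$, capacity is 2, so expert 0 (which receives four tokens) keeps only the first two; raising $C_f$ to 1.25 gives capacity 3 and drops only one token.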

MoE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k routing and load balancing.

    Replaces the standard FFN in a Transformer block. Each expert
    is a SwiGLU FFN, and the router selects the top-k experts per token.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
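        # Stored for reference only: this simplified layer does not
        # enforce expert capacity or drop tokens.
        self.capacity_factor = capacity_factor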

        # Router: linear projection to expert logits
        self.router = nn.Linear(d_model, n_experts, bias=False)

        # Expert FFNs (each is a SwiGLU)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass with top-k routing.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).

        Returns:
            output: MoE output, same shape as input.
            aux_loss: Load-balancing auxiliary loss (scalar).
        """
        B, T, D = x.shape
        x_flat = x.reshape(-1, D)  # (B*T, D); reshape also handles non-contiguous input
        num_tokens = x_flat.shape[0]

        # Compute router logits and probabilities
        router_logits = self.router(x_flat)  # (B*T, n_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        top_k_weights, top_k_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )
        # Renormalize weights over selected experts
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Compute auxiliary load-balancing loss
        aux_loss = self._load_balancing_loss(router_probs, top_k_indices)

        # Compute expert outputs
        output = torch.zeros_like(x_flat)

        for i in range(self.n_experts):
            # Find tokens routed to this expert (across any of the top-k slots)
            expert_mask = (top_k_indices == i)  # (num_tokens, top_k)
            token_mask = expert_mask.any(dim=-1)  # (num_tokens,)

            if not token_mask.any():
                continue

            # Get the tokens for this expert
            expert_input = x_flat[token_mask]  # (n_selected, D)

            # Compute expert output
            expert_output = self.experts[i](expert_input)  # (n_selected, D)

            # Routing weight for this expert. torch.topk returns distinct
            # indices, so each selected token matches exactly one top-k slot
            # and the sum simply picks out that slot's weight.
            weights = expert_mask[token_mask] * top_k_weights[token_mask]
            weight = weights.sum(dim=-1, keepdim=True)  # (n_selected, 1)

            output[token_mask] += weight * expert_output

        return output.view(B, T, D), aux_loss

    def _load_balancing_loss(
        self,
        router_probs: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """Compute the auxiliary load-balancing loss.

        Encourages uniform distribution of tokens across experts.
        """
        num_tokens = router_probs.shape[0]

        # f_i: fraction of tokens routed to each expert
        # Count how many tokens have expert i in their top-k
        expert_counts = torch.zeros(
            self.n_experts, device=router_probs.device
        )
        for k in range(self.top_k):
            expert_counts.scatter_add_(
                0, top_k_indices[:, k],
                torch.ones(num_tokens, device=router_probs.device)
            )
        f = expert_counts / num_tokens  # (n_experts,)

        # P_i: average router probability for each expert
        P = router_probs.mean(dim=0)  # (n_experts,)

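        # Auxiliary loss: N * sum(f_i * P_i). The caller scales it by alpha
        # (e.g. 0.01). Here f sums to top_k, since each token is counted
        # once per selected expert.
        aux_loss = self.n_experts * (f * P).sum()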

        return aux_loss

Expert Parallelism

MoE introduces a new dimension of parallelism in distributed training. Since different experts are independent modules, they can be placed on different devices:

| Parallelism Strategy | What Is Distributed | Communication Pattern |
|---|---|---|
| Data Parallelism | Same model, different batches | AllReduce on gradients |
| Tensor Parallelism | Weight matrices split across devices | AllReduce on activations |
| Pipeline Parallelism | Different layers on different devices | Point-to-point between stages |
| Expert Parallelism | Different experts on different devices | All-to-All for token routing |

Expert parallelism uses All-to-All communication: tokens are dispatched from their "home" device to whichever device hosts their selected expert, processed there, and then sent back. The communication cost is proportional to the number of tokens that cross device boundaries, which depends on the routing distribution.

In Mixtral, 8 experts can be placed on 8 GPUs with one expert per GPU. Each GPU holds the full attention weights (shared across all tokens) plus one expert FFN. During the FFN computation:

  1. Each GPU computes router logits for its local tokens.
  2. All-to-All communication dispatches each token to the GPU(s) holding its selected expert(s).
  3. Each GPU processes the tokens assigned to its local expert.
  4. All-to-All communication returns the expert outputs to the originating GPUs.
  5. Each GPU combines the expert outputs with the routing weights.
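Steps 2 and 4 hinge on knowing how many tokens each device sends to every other device. This sketch computes that send-count matrix for a hypothetical random top-2 routing, assuming expert $e$ lives on device $e$. In a real system these counts would become the split sizes for a collective such as torch.distributed.all_to_all_single, and the return trip in step 4 uses the transposed matrix; no actual communication happens here.

```python
import torch


def all_to_all_counts(expert_idx: torch.Tensor, n_devices: int) -> torch.Tensor:
    """counts[s, d] = token copies that device s must send to device d.

    expert_idx: (n_devices, tokens_per_device, top_k) expert ids, with
    expert e assumed to live on device e (one expert per device).
    """
    counts = torch.zeros(n_devices, n_devices, dtype=torch.long)
    for src in range(n_devices):
        dests = expert_idx[src].flatten()       # one entry per (token, slot)
        counts[src] += torch.bincount(dests, minlength=n_devices)
    return counts


torch.manual_seed(0)
n_dev, tokens, k = 8, 1024, 2
# Hypothetical routing: two distinct experts per token, chosen at random.
expert_idx = torch.rand(n_dev, tokens, n_dev).topk(k, dim=-1).indices
counts = all_to_all_counts(expert_idx, n_dev)
crossing = counts.sum() - counts.diag().sum()   # copies that leave their home device
print(f"{crossing.item()} of {counts.sum().item()} token copies cross devices")
```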

MoE Design Decisions

Several design choices significantly impact MoE performance:

Number of experts. More experts means more total capacity for the same active parameter count, but also more routing complexity and potential for load imbalance. Mixtral uses 8 experts; some research models have explored 64 or more.

Top-k value. $k=1$ is most efficient but can be unstable during training. $k=2$ is the most common choice, providing redundancy in routing while maintaining efficiency. Higher $k$ values approach dense computation and reduce the efficiency benefit.

Which layers use MoE. Not every layer needs to be an MoE layer. Some architectures apply MoE only to every other layer, or only to deeper layers. The Mixtral architecture applies MoE to every layer.

Expert granularity. An alternative to having a few large experts is to have many small experts with a higher top-k. DeepSeekMoE (Dai et al., 2024) explores "fine-grained experts", where each expert is smaller but more experts are selected per token, potentially leading to more specialized routing.
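The appeal of fine-grained experts is combinatorial. With illustrative sizes (not taken from any specific model), splitting each of 16 experts into 4 quarter-size experts and routing to the top-8 instead of the top-2 keeps the active parameters per token the same while vastly increasing the number of expert combinations a token can use:

```python
from math import comb

n_coarse, k_coarse = 16, 2                            # 16 experts, top-2
split = 4                                             # each expert -> 4 smaller ones
n_fine, k_fine = n_coarse * split, k_coarse * split   # 64 experts, top-8

# Active parameters per token, in units of one coarse expert.
active_coarse = k_coarse * 1.0
active_fine = k_fine * (1.0 / split)

print(active_coarse, active_fine)          # 2.0 2.0
print(f"{comb(n_coarse, k_coarse):,}")     # 120
print(f"{comb(n_fine, k_fine):,}")         # 4,426,165,368
```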

Shared Expert Variants

DeepSeekMoE (Dai et al., 2024), whose design DeepSeek-V2 builds on, adds a shared expert that processes every token, combined with routed experts that are selected per token:

$$\text{MoE}_{\text{shared}}(x) = E_{\text{shared}}(x) + \sum_{i=1}^{N} G(x)_i \cdot E_i(x)$$

The shared expert captures common patterns that are useful for all tokens, while the routed experts handle specialized knowledge. This can improve the stability of training and reduce the impact of routing errors.
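The shared-expert formula can be sketched as a small module. To keep it short, this hypothetical SharedExpertMoE evaluates every routed expert densely and lets the sparse top-k gates zero out the unselected ones; that is easy to verify but wastes compute. It uses a single shared expert, as in the formula above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySwiGLU(nn.Module):
    """Minimal bias-free SwiGLU FFN used for both shared and routed experts."""

    def __init__(self, d: int, d_ff: int):
        super().__init__()
        self.gate = nn.Linear(d, d_ff, bias=False)
        self.up = nn.Linear(d, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))


class SharedExpertMoE(nn.Module):
    """E_shared(x) + sum_i G(x)_i E_i(x), with top-k softmax gates (dense sketch)."""

    def __init__(self, d: int, d_ff: int, n_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.shared = TinySwiGLU(d, d_ff)     # sees every token
        self.experts = nn.ModuleList([TinySwiGLU(d, d_ff) for _ in range(n_experts)])
        self.router = nn.Linear(d, n_experts, bias=False)

    def gates(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.router(x)
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)
        masked = torch.full_like(logits, float("-inf")).scatter(-1, top_idx, top_vals)
        return F.softmax(masked, dim=-1)      # (..., n_experts), zeros off the top-k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expert_out = torch.stack([e(x) for e in self.experts], dim=-2)  # (..., N, d)
        routed = (self.gates(x).unsqueeze(-1) * expert_out).sum(dim=-2)
        return self.shared(x) + routed


torch.manual_seed(0)
moe = SharedExpertMoE(d=16, d_ff=32)
x = torch.randn(2, 5, 16)   # (batch, seq_len, d)
y = moe(x)
```

The output has the same shape as the input, and the shared expert's contribution is added for every token regardless of how the router decides.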

5. Putting It All Together: The Modern FFN Block

A complete transformer block in a modern dense model (like LLaMA 3) uses:

class ModernTransformerBlock(nn.Module):
    """Complete transformer block following the LLaMA 3 architecture.

    Combines:
    - Pre-LN with RMSNorm (Part 2)
    - Grouped-Query Attention with RoPE (Part 3)
    - SwiGLU FFN (Part 4)
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        d_ff: int | None = 14336,  # LLaMA 3 8B; pass None for the 8/3 rule (11008)
        max_seq_len: int = 8192,
        rope_theta: float = 500000.0
    ):
        super().__init__()
        # RMSNorm for Pre-LN
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

        # GQA with RoPE
        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len,
            rope_theta=rope_theta
        )

        # SwiGLU FFN
        self.ffn = SwiGLU(d_model, d_ff)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Pre-RMSNorm + GQA + Residual
        h = self.attn_norm(x)
        h = self.attention(h, start_pos=start_pos, mask=mask)
        x = x + h

        # Pre-RMSNorm + SwiGLU + Residual
        h = self.ffn_norm(x)
        h = self.ffn(h)
        x = x + h

        return x

And for an MoE model (like Mixtral), the only change is replacing the SwiGLU with an MoELayer:

class MoETransformerBlock(nn.Module):
    """Transformer block with Mixture of Experts FFN (Mixtral-style)."""

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        d_ff: int = 14336,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 32768
    ):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)

        self.attention = GroupedQueryAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=max_seq_len
        )

        # MoE replaces the single FFN
        self.moe = MoELayer(
            d_model=d_model,
            d_ff=d_ff,
            n_experts=n_experts,
            top_k=top_k
        )

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention (same as dense)
        h = self.attn_norm(x)
        h = self.attention(h, mask=mask)
        x = x + h

        # MoE FFN
        h = self.ffn_norm(x)
        h, aux_loss = self.moe(h)
        x = x + h

        return x, aux_loss

Summary: Evolution of the FFN

| Era | FFN Type | Activation | Weight Matrices | Hidden Dim | Used By |
|---|---|---|---|---|---|
| 2017 | Standard MLP | ReLU | 2 ($W_1, W_2$) | $4d$ | Original Transformer |
| 2018-2020 | Standard MLP | GELU | 2 ($W_1, W_2$) | $4d$ | BERT, GPT-2, GPT-3 |
| 2022+ | Gated (SwiGLU) | SiLU/Swish | 3 ($W_g, W_u, W_d$) | $\approx \frac{8}{3}d$ (varies by model) | LLaMA, Mistral, PaLM |
| 2023+ | MoE + SwiGLU | SiLU/Swish | $N \times 3$ + router | $\approx \frac{8}{3}d$ per expert (varies by model) | Mixtral, GPT-4 (rumored) |

The progression from simple ReLU MLPs to gated SwiGLU networks to Mixture of Experts represents a steady refinement: each generation offers more expressive feature transformations, better gradient flow, or more efficient scaling. The SwiGLU FFN, in particular, has become as fundamental to modern LLM design as the attention mechanism itself.


In the next post, we will explore Part 5: Training Improvements -- the optimization techniques that make training these architectures feasible at scale, including AdamW, learning rate scheduling, mixed-precision training, and gradient checkpointing.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
  2. Hendrycks, D. and Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)." arXiv:1606.08415.
  3. Ramachandran, P., Zoph, B., Le, Q. V. (2017). "Searching for Activation Functions." arXiv:1710.05941.
  4. Dauphin, Y. N., Fan, A., Auli, M., Grangier, D. (2017). "Language Modeling with Gated Convolutional Networks." ICML 2017.
  5. Shazeer, N. (2020). "GLU Variants Improve Transformer." arXiv:2002.05202.
  6. Geva, M., Schuster, R., Berant, J., Levy, O. (2021). "Transformer Feed-Forward Layers Are Key-Value Memories." EMNLP 2021.
  7. Fedus, W., Zoph, B., Shazeer, N. (2022). "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity." JMLR 2022.
  8. Jiang, A. Q., Sablayrolles, A., Roux, A., et al. (2024). "Mixtral of Experts." arXiv:2401.04088.
  9. Touvron, H., Lavril, T., Izacard, G., et al. (2023). "LLaMA: Open and Efficient Foundation Language Models." arXiv:2302.13971.
  10. Dai, D., Deng, C., Zhao, C., et al. (2024). "DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models." arXiv:2401.06066.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
