Transformer Deep Dive: Part 6 - Inference Optimization
Training a large language model may cost millions of dollars, but inference --- the process of generating text from a trained model --- accounts for the vast majority of total compute expenditure over a model's lifetime. A model trained once will serve billions of requests. At this scale, a 2x improvement in inference throughput directly halves serving costs, making inference optimization one of the highest-leverage problems in production ML.
In this post, we dissect the key techniques that make LLM serving practical: why autoregressive decoding is memory-bound rather than compute-bound, how the KV-cache eliminates redundant computation (and why it creates its own memory challenges), quantization methods that shrink models by 4-8x with minimal quality loss, speculative decoding that generates multiple tokens per forward pass, and continuous batching with PagedAttention for maximizing GPU utilization.
The Inference Challenge
LLM inference has two distinct phases with very different computational characteristics:
- Prefill (prompt processing): The entire input prompt is processed in a single forward pass. This phase is compute-bound --- the GPU's arithmetic units are the bottleneck, similar to training. Matrix multiplications operate on full sequence-length tensors, achieving high arithmetic intensity.
- Decode (token generation): Tokens are generated one at a time, autoregressively. Each step requires loading the entire model from GPU memory to compute a single token's logits. This phase is memory-bandwidth-bound --- the GPU spends most of its time waiting for data to arrive from HBM, not performing arithmetic.
The key metrics for production serving reflect these two phases:
| Metric | Description | Typical Target |
|---|---|---|
| Time to First Token (TTFT) | Latency of the prefill phase | < 500ms |
| Time per Output Token (TPOT) | Latency of each decode step | < 50ms |
| Throughput | Total tokens generated per second across all requests | Maximize |
| Memory Efficiency | Fraction of GPU memory used productively | Maximize |
Why Decoding is Memory-Bound
To understand the memory bottleneck, consider the arithmetic intensity of a single decode step. For a model with $N$ parameters, generating one token requires roughly $2N$ FLOPs (one multiply-add per parameter). With a 70B model, that is $1.4 \times 10^{11}$ FLOPs. Meanwhile, loading all 70B parameters in BF16 requires reading $1.4 \times 10^{11}$ bytes (140 GB) from HBM.

The arithmetic intensity is:

$$\text{Arithmetic intensity} = \frac{2N \text{ FLOPs}}{2N \text{ bytes}} = 1 \text{ FLOP/byte}$$

Modern GPUs have compute-to-bandwidth ratios far exceeding 1:
| GPU | HBM Bandwidth (TB/s) | BF16 TFLOPs | Compute:Bandwidth Ratio |
|---|---|---|---|
| A100 80GB SXM | 2.0 | 312 | 156:1 |
| H100 80GB SXM | 3.35 | 990 | 296:1 |
| H200 141GB | 4.8 | 990 | 206:1 |
With a ratio of ~200:1, the GPU can perform 200 FLOPs for every byte loaded, but single-token decoding only needs 1 FLOP per byte. This means the compute units are idle more than 99% of the time during decoding, waiting for parameters to stream from memory. This is the fundamental reason why batching (processing multiple requests simultaneously) and reducing memory transfers (via quantization, caching) are so impactful.
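To make the bandwidth bound concrete, here is a back-of-the-envelope sketch (our own illustration, not profiler output) of the floor on per-token decode latency implied by weight streaming alone:

```python
# Every decode step must stream all model weights from HBM, so
# time-per-token >= model_bytes / memory_bandwidth, regardless of FLOPs.
def decode_floor_ms(n_params: float, bytes_per_param: float,
                    bandwidth_tb_s: float) -> float:
    model_bytes = n_params * bytes_per_param
    return model_bytes / (bandwidth_tb_s * 1e12) * 1e3  # milliseconds

# 70B model in BF16 (2 bytes/param) on an H100 (3.35 TB/s HBM):
latency = decode_floor_ms(70e9, 2, 3.35)
print(f"{latency:.1f} ms/token")  # 41.8 ms/token -> at most ~24 tokens/s
```

This is why halving the bytes moved (quantization) or amortizing the weight load over many requests (batching) translates almost directly into throughput.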
KV-Cache
The Redundant Computation Problem
In autoregressive generation, we produce tokens one at a time. At each step $t$, the self-attention mechanism computes:

$$\text{Attention}(q_t, K_{1:t}, V_{1:t}) = \text{softmax}\!\left(\frac{q_t K_{1:t}^\top}{\sqrt{d_{\text{head}}}}\right) V_{1:t}$$

The query $q_t$ only contains the new token, but the keys $K_{1:t}$ and values $V_{1:t}$ include all previous tokens. Without caching, we would recompute the key and value projections for every previous token at every step:
Step 1: Compute K₁, V₁ for token "The"
Step 2: Compute K₁, V₁ for "The", K₂, V₂ for "cat" → K₁,V₁ recomputed!
Step 3: Compute K₁,V₁, K₂,V₂, K₃,V₃ for "The cat sat" → K₁,V₁,K₂,V₂ recomputed!
Step 4: Compute K₁..₃, V₁..₃, K₄, V₄ for "The cat sat on" → K₁..₃,V₁..₃ recomputed!
The total computation for generating $T$ tokens scales as $O(T^2)$, since step $t$ processes $t$ tokens.
How KV-Cache Works
The solution is straightforward: cache the key and value projections once computed, and reuse them in subsequent steps.
With KV-caching, each decode step only computes $q_t$, $k_t$, $v_t$ for the single new token, appends $k_t$ and $v_t$ to the cache, and computes attention using the cached keys and values:
Step 1: Compute K₁, V₁ → Cache: [K₁], [V₁]
Step 2: Compute K₂, V₂ → Cache: [K₁, K₂], [V₁, V₂] → Only new token projected!
Step 3: Compute K₃, V₃ → Cache: [K₁, K₂, K₃], [V₁, V₂, V₃]
Step 4: Compute K₄, V₄ → Cache: [K₁..₄], [V₁..₄]
Total computation for $T$ tokens drops from $O(T^2)$ to $O(T)$ --- from quadratic to linear in sequence length.
Implementation
A production KV-cache implementation pre-allocates memory for the maximum sequence length and fills it incrementally:
import torch
import torch.nn as nn
import torch.nn.functional as F
class KVCache:
"""Pre-allocated KV-Cache for efficient autoregressive decoding.
Pre-allocates tensors for keys and values up to the maximum sequence
length, then fills them incrementally as tokens are generated.
"""
def __init__(self, max_batch_size: int, max_seq_len: int,
n_kv_heads: int, head_dim: int,
dtype=torch.bfloat16, device="cuda"):
self.max_seq_len = max_seq_len
self.cache_k = torch.zeros(
(max_batch_size, n_kv_heads, max_seq_len, head_dim),
dtype=dtype, device=device,
)
self.cache_v = torch.zeros(
(max_batch_size, n_kv_heads, max_seq_len, head_dim),
dtype=dtype, device=device,
)
self.seq_len = 0
def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
"""Append new keys and values to the cache.
Args:
k_new: New key tensor of shape (batch, n_kv_heads, new_seq_len, head_dim)
v_new: New value tensor of shape (batch, n_kv_heads, new_seq_len, head_dim)
Returns:
Full cached keys and values up to the current position.
"""
batch_size, _, new_seq_len, _ = k_new.shape
end_pos = self.seq_len + new_seq_len
assert end_pos <= self.max_seq_len, "Sequence length exceeds cache capacity"
self.cache_k[:batch_size, :, self.seq_len:end_pos, :] = k_new
self.cache_v[:batch_size, :, self.seq_len:end_pos, :] = v_new
self.seq_len = end_pos
return (
self.cache_k[:batch_size, :, :self.seq_len, :],
self.cache_v[:batch_size, :, :self.seq_len, :],
)
def reset(self):
self.seq_len = 0
class CachedAttention(nn.Module):
"""Multi-head attention with KV-cache support for inference."""
def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = d_model // n_heads
self.n_rep = n_heads // n_kv_heads # GQA repetition factor
self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)
def forward(self, x: torch.Tensor, kv_cache: KVCache = None):
batch, seq_len, _ = x.shape
q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
k = self.k_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
# Transpose to (batch, heads, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Update cache and get full key/value history
if kv_cache is not None:
k, v = kv_cache.update(k, v)
# Expand KV heads for GQA: (batch, n_kv_heads, seq, dim) -> (batch, n_heads, seq, dim)
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
# Scaled dot-product attention
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
# Causal mask (only needed during prefill; during decode, q has length 1)
if seq_len > 1:
causal_mask = torch.triu(
torch.full((seq_len, k.size(-2)), float("-inf"), device=x.device),
diagonal=k.size(-2) - seq_len + 1,
)
attn_weights = attn_weights + causal_mask
attn_weights = F.softmax(attn_weights, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).reshape(batch, seq_len, -1)
return self.o_proj(output)
KV-Cache Memory Analysis
The KV-cache stores two tensors (K and V) per layer, per head, for every token in the sequence. For a model with $L$ layers, $H_{kv}$ key-value heads, and head dimension $d_{\text{head}}$, the cache size per token per batch element is:

$$\text{KV bytes per token} = 2 \times L \times H_{kv} \times d_{\text{head}} \times \text{bytes per element}$$

For the full sequence of length $S$ with batch size $B$:

$$\text{KV-cache size} = B \times S \times 2 \times L \times H_{kv} \times d_{\text{head}} \times \text{bytes per element}$$
Let us compute concrete numbers for LLaMA models in BF16 (2 bytes per element):
| Model | Layers | KV Heads | Head Dim | KV per Token | 4K Context | 32K Context | 128K Context |
|---|---|---|---|---|---|---|---|
| LLaMA-2 7B | 32 | 32 | 128 | 0.5 MB | 2 GB | 16 GB | 64 GB |
| LLaMA-2 13B | 40 | 40 | 128 | 0.8 MB | 3.1 GB | 25 GB | 100 GB |
| LLaMA-2 70B | 80 | 8 (GQA) | 128 | 0.3 MB | 1.25 GB | 10 GB | 40 GB |
| LLaMA-3 405B | 126 | 8 (GQA) | 128 | 0.5 MB | 2 GB | 16 GB | 64 GB |
Notice that LLaMA-2 70B uses GQA with only 8 KV heads instead of 64, which reduces its KV-cache by 8x compared to full MHA. This is one of the primary motivations for Grouped-Query Attention (as we discussed in Part 3). Without GQA, a 70B model at 128K context would need 320 GB of KV-cache alone --- more than any single GPU can hold.
For a serving scenario with batch size 32 and 4K context, a 70B GQA model needs 32 × 1.25 GB = 40 GB of KV-cache --- roughly half the capacity of an 80 GB A100. The KV-cache, not the model weights, becomes the binding memory constraint for high-throughput serving.
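The table entries above can be reproduced with a few lines (a quick sketch; sizes reported in GiB, matching the table's binary "GB"):

```python
def kv_cache_gb(n_layers: int, n_kv_heads: int, head_dim: int,
                seq_len: int, batch_size: int,
                bytes_per_elem: int = 2) -> float:
    """KV-cache size in GiB: 2 tensors (K and V) x layers x KV heads
    x head_dim x bytes per element, per token, times batch x sequence."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem
    return batch_size * seq_len * per_token / 2**30

# LLaMA-2 7B (MHA, 32 KV heads), one 4K sequence in BF16:
print(kv_cache_gb(32, 32, 128, 4096, 1))    # 2.0

# LLaMA-2 70B (GQA, 8 KV heads), batch 32 at 4K context:
print(kv_cache_gb(80, 8, 128, 4096, 32))    # 40.0
```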
Quantization
The Core Idea
Quantization reduces the numerical precision of model weights (and optionally activations) from 16 or 32 bits to 8, 4, or even lower bits. Since decoding is memory-bandwidth-bound, reducing the size of the weights proportionally increases the number of tokens we can generate per second.
The basic formulation for uniform affine quantization maps a floating-point tensor $x$ to $b$-bit integers:

$$x_q = \text{clamp}\!\left(\text{round}\!\left(\frac{x}{s}\right) + z,\; 0,\; 2^b - 1\right)$$

where $s$ is the scale and $z$ is the zero-point. Dequantization recovers an approximation:

$$\hat{x} = s \cdot (x_q - z)$$

The quantization error is the rounding error $\epsilon = \hat{x} - x$, which we want to minimize while using as few bits as possible.
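A minimal round-trip sketch of this affine scheme (illustrative only, not a production kernel):

```python
import torch

def affine_quantize(x: torch.Tensor, bits: int = 8):
    """Uniform affine (asymmetric) quantization over the tensor's range."""
    qmin, qmax = 0, 2**bits - 1
    s = (x.max() - x.min()).clamp(min=1e-10) / (qmax - qmin)  # scale
    z = torch.round(-x.min() / s).clamp(qmin, qmax)           # zero-point
    x_q = torch.clamp(torch.round(x / s) + z, qmin, qmax)
    return x_q, s, z

def affine_dequantize(x_q: torch.Tensor, s, z) -> torch.Tensor:
    return s * (x_q - z)

x = torch.randn(1024)
x_q, s, z = affine_quantize(x, bits=8)
x_hat = affine_dequantize(x_q, s, z)
# The round-trip error is bounded by one quantization step:
assert (x_hat - x).abs().max() <= s
```

INT8 over a randn tensor keeps the error tiny; at 4 bits the same scheme has only 16 levels, which is why the group-wise and calibration-based methods below become necessary.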
Weight-Only vs. Weight-and-Activation Quantization
For LLMs, weight-only quantization is the dominant approach. Weights are static (they do not change between requests) and can be carefully quantized offline. Activations, on the other hand, vary with each input and contain outliers that make quantization harder.
| Approach | What is Quantized | Compute Kernel | Typical Formats |
|---|---|---|---|
| Weight-only | Weights (offline) | W_int x A_fp16 mixed-precision matmul | INT4, INT8, NF4 |
| Weight + Activation | Both (online) | W_int x A_int integer matmul | INT8 x INT8 |
Weight-only INT4 quantization reduces memory by 4x while keeping activations in FP16/BF16, achieving the memory savings needed to fit large models on fewer GPUs while maintaining most of the model quality.
Group Quantization
Quantizing an entire weight matrix with a single scale and zero-point is too coarse --- outlier values force a wide range, wasting bits on the majority of values that cluster near zero. Group quantization divides each row (or column) into groups of $g$ elements (typically $g = 128$) and assigns a separate scale and zero-point to each group:
import torch
def quantize_symmetric(weight: torch.Tensor, bits: int = 4,
group_size: int = 128) -> tuple:
"""Symmetric group quantization.
Divides each row into groups and quantizes each group independently
with its own scale factor. Symmetric quantization centers around zero,
so no zero-point is needed.
Args:
weight: FP16/BF16 weight tensor of shape (out_features, in_features)
bits: Number of quantization bits
group_size: Number of elements per quantization group
Returns:
Tuple of (quantized_weight, scales)
"""
out_features, in_features = weight.shape
assert in_features % group_size == 0
# Reshape into groups: (out_features, n_groups, group_size)
n_groups = in_features // group_size
weight_grouped = weight.reshape(out_features, n_groups, group_size)
# Compute per-group scale (symmetric: max absolute value)
qmax = 2 ** (bits - 1) - 1 # e.g., 7 for 4-bit
scales = weight_grouped.abs().amax(dim=-1, keepdim=True) / qmax
scales = scales.clamp(min=1e-10) # Avoid division by zero
# Quantize
weight_int = torch.round(weight_grouped / scales).clamp(-qmax, qmax).to(torch.int8)
# Reshape back
weight_int = weight_int.reshape(out_features, in_features)
scales = scales.reshape(out_features, n_groups)
return weight_int, scales
def dequantize_symmetric(weight_int: torch.Tensor, scales: torch.Tensor,
group_size: int = 128) -> torch.Tensor:
"""Dequantize a symmetric group-quantized weight tensor."""
out_features, in_features = weight_int.shape
n_groups = in_features // group_size
weight_grouped = weight_int.reshape(out_features, n_groups, group_size).float()
scales_expanded = scales.unsqueeze(-1)
return (weight_grouped * scales_expanded).reshape(out_features, in_features)
The storage overhead of the scale factors is small: for group size 128 and 4-bit quantization, each group of 128 values (64 bytes in INT4) stores one FP16 scale (2 bytes), adding ~3% overhead.
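The overhead figure is simple arithmetic:

```python
# One FP16 scale (16 bits) per group of 128 INT4 values (512 bits):
group_bits = 128 * 4
scale_bits = 16
overhead = scale_bits / group_bits
print(f"{overhead:.1%}")  # 3.1%
```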
GPTQ: Optimal Rounding via Second-Order Information
GPTQ (Frantar et al., 2022) goes beyond naive rounding by using second-order information (the Hessian of the layer-wise reconstruction error) to determine the optimal rounding direction for each weight. The key insight is that rounding a single weight up vs. down affects the entire output of the layer, and the Hessian tells us which direction minimizes that impact.
GPTQ processes weights one column at a time, and for each weight, it:
- Rounds the weight to the nearest quantized value.
- Compensates for the rounding error by adjusting the remaining (not yet quantized) weights in the same row, using the inverse Hessian to determine the optimal compensation.
This produces significantly better quantized models than naive rounding, especially at 4-bit and 3-bit precision.
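The per-row inner loop can be sketched as follows. This is a heavily simplified illustration of the error-compensation idea, not real GPTQ, which works on the Cholesky factor of a damped inverse Hessian and quantizes weights in blocks for efficiency:

```python
import torch

def gptq_round_row(w: torch.Tensor, h_inv: torch.Tensor,
                   scale: float, qmax: int = 7) -> torch.Tensor:
    """Sequentially round one weight row, absorbing each rounding error
    into the not-yet-quantized weights via the inverse Hessian (the
    OBQ update rule that GPTQ builds on)."""
    w = w.clone()
    q = torch.empty_like(w)
    for i in range(w.numel()):
        q[i] = torch.round(w[i] / scale).clamp(-qmax, qmax)
        err = (w[i] - q[i] * scale) / h_inv[i, i]
        # Shift the remaining weights so the layer output changes as
        # little as possible given that w[i] is now fixed.
        w[i:] -= err * h_inv[i, i:]
    return q

# Sanity check: with an identity Hessian there are no cross-weight
# interactions, so the result degenerates to plain round-to-nearest.
w = torch.randn(16)
q = gptq_round_row(w, torch.eye(16), scale=0.1)
assert torch.equal(q, torch.round(w / 0.1).clamp(-7, 7))
```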
AWQ: Activation-Aware Weight Quantization
AWQ (Lin et al., 2024) observes that not all weights are equally important. Weights corresponding to large activation magnitudes contribute more to the output and should be quantized more carefully. AWQ identifies "salient" weight channels by analyzing activation statistics from a small calibration dataset, then applies per-channel scaling to protect these important weights before quantization.
The scaling is chosen to minimize the quantization error for the most important channels, at the cost of slightly higher error for less important ones --- a tradeoff that consistently improves quality.
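The core trick --- scaling weight channels up while scaling the matching activations down, so the product is unchanged --- can be sketched like this (our illustration with a hypothetical activation statistic; the real method searches for the best scaling exponent on calibration data):

```python
import torch

def awq_scale_sketch(w: torch.Tensor, act_magnitude: torch.Tensor,
                     alpha: float = 0.5):
    """Scale weight columns by activation_magnitude**alpha before
    quantization; the inverse scale is applied to the activations
    (or folded into the previous layer), leaving W @ x unchanged."""
    s = act_magnitude.clamp(min=1e-5) ** alpha   # per-input-channel scale
    w_scaled = w * s                             # salient channels get finer steps
    # ... quantize w_scaled here ...
    return w_scaled, s

w = torch.randn(64, 128)
act = torch.rand(128) * 10        # hypothetical calibration statistics
w_s, s = awq_scale_sketch(w, act)
x = torch.randn(128)
# Mathematically equivalent: (W * s) @ (x / s) == W @ x
assert torch.allclose(w @ x, w_s @ (x / s), atol=1e-3, rtol=1e-4)
```

After scaling, the salient (large-activation) channels occupy more of each quantization group's range, so rounding hurts them less.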
Quantization Landscape Summary
| Method | Bits | Calibration Data | Key Idea | Typical Quality (vs FP16) |
|---|---|---|---|---|
| Round-to-Nearest (RTN) | 8 / 4 | None | Naive rounding | Good at INT8, poor at INT4 |
| GPTQ | 4 / 3 | ~128 samples | Hessian-based optimal rounding | ≤1% degradation at INT4 |
| AWQ | 4 | ~128 samples | Activation-aware channel scaling | ≤1% degradation at INT4 |
| SqueezeLLM | 4 / 3 | ~128 samples | Non-uniform quantization | Competitive with GPTQ |
| NF4 (QLoRA) | 4 | None | Normal-float quantization | Designed for fine-tuning |
| QuIP# | 2 / 4 | ~128 samples | Incoherence processing + lattice codes | Best quality at 2-bit |
In practice, loading a 4-bit model takes only a few lines with Hugging Face Transformers and bitsandbytes:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
# INT4 quantization with bitsandbytes (NF4 format)
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # Normal Float 4-bit
bnb_4bit_compute_dtype=torch.bfloat16, # Compute in BF16
bnb_4bit_use_double_quant=True, # Quantize the quantization constants too
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
quantization_config=quantization_config,
device_map="auto", # Automatically shard across available GPUs
torch_dtype=torch.bfloat16,
)
# 70B model in ~35 GB instead of ~140 GB
Speculative Decoding
The Insight
As we established, autoregressive decoding is memory-bound: the GPU loads all model weights to generate a single token. The compute units are vastly underutilized. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) exploits this idle compute by generating multiple candidate tokens cheaply, then verifying them in parallel with the large model.
The key property that makes speculative decoding work is that verification is nearly free in autoregressive models. If we have $K$ candidate tokens, we can compute the target model's probability for all $K$ positions in a single forward pass (the same forward pass we would need for the next token anyway, just with extra tokens in the input). The marginal cost of verification is negligible compared to a fresh forward pass for each token.
The Algorithm in Detail
Speculative decoding proceeds in rounds. Each round:
- Draft phase: A small, fast draft model (e.g., a 1B model when the target is 70B) generates $K$ candidate tokens autoregressively. This is fast because the draft model is small.
- Verification phase: The target model processes the entire candidate sequence in a single forward pass, producing probability distributions for each position.
- Accept/reject: For each candidate token $x$, we compare the draft model's probability $q(x)$ with the target model's probability $p(x)$. We accept the token with probability $\min(1, p(x)/q(x))$. This rejection sampling scheme guarantees that the final output distribution is identical to the target model's distribution --- speculative decoding is lossless.
- Correction: At the first rejected position, we sample from the adjusted distribution $\mathrm{norm}(\max(0, p - q))$ to correct for the draft model's bias.
import torch
import torch.nn.functional as F
def speculative_decode(
draft_model,
target_model,
input_ids: torch.Tensor,
max_new_tokens: int,
draft_length: int = 5,
temperature: float = 1.0,
):
"""Speculative decoding with rejection sampling.
Generates tokens that are distributed identically to sampling from the
target model, but potentially much faster by amortizing the target
model's forward pass over multiple tokens.
Args:
draft_model: Small, fast model for generating candidates.
target_model: Large model whose distribution we want to sample from.
input_ids: Input token IDs, shape (1, seq_len).
max_new_tokens: Maximum tokens to generate.
draft_length: Number of speculative tokens per round (K).
temperature: Sampling temperature.
Returns:
Generated token IDs.
"""
generated = input_ids.clone()
tokens_generated = 0
while tokens_generated < max_new_tokens:
# --- Draft phase ---
# Generate K candidate tokens with the small model
draft_tokens = []
draft_probs = []
draft_input = generated.clone()
for _ in range(draft_length):
with torch.no_grad():
logits = draft_model(draft_input).logits[:, -1, :]
probs = F.softmax(logits / temperature, dim=-1)
token = torch.multinomial(probs, num_samples=1)
draft_tokens.append(token)
draft_probs.append(probs)
draft_input = torch.cat([draft_input, token], dim=-1)
draft_tokens = torch.cat(draft_tokens, dim=-1) # (1, K)
# --- Verification phase ---
# Single forward pass through target model for all K+1 positions
candidate_seq = torch.cat([generated, draft_tokens], dim=-1)
with torch.no_grad():
target_logits = target_model(candidate_seq).logits
# Extract target probabilities at the K draft positions
# Position indices: last K+1 positions of the output
start_pos = generated.size(1) - 1
target_probs_all = F.softmax(
target_logits[:, start_pos:start_pos + draft_length + 1, :] / temperature,
dim=-1,
)
# --- Accept/Reject ---
n_accepted = 0
for i in range(draft_length):
token_id = draft_tokens[0, i].item()
p_target = target_probs_all[0, i, token_id].item()
p_draft = draft_probs[i][0, token_id].item()
# Rejection sampling: accept with probability min(1, p/q)
acceptance_prob = min(1.0, p_target / max(p_draft, 1e-10))
if torch.rand(1).item() < acceptance_prob:
n_accepted += 1
else:
# Reject: sample from the adjusted distribution
adjusted = torch.clamp(
target_probs_all[0, i, :] - draft_probs[i][0, :],
min=0.0,
)
adjusted = adjusted / adjusted.sum()
correction_token = torch.multinomial(adjusted, num_samples=1)
generated = torch.cat([
generated, draft_tokens[:, :i], correction_token.unsqueeze(0)
], dim=-1)
tokens_generated += i + 1
break
else:
# All K tokens accepted! Sample one more from the target at position K+1
bonus_token = torch.multinomial(
target_probs_all[0, draft_length, :].unsqueeze(0), num_samples=1
)
generated = torch.cat([generated, draft_tokens, bonus_token], dim=-1)
tokens_generated += draft_length + 1
return generated
Expected Speedup
If each draft token has acceptance probability $\alpha$ (determined by how well the draft model approximates the target), the expected number of tokens generated per verification round with $K$ draft tokens is:

$$E[\text{tokens per round}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

The wallclock speedup depends on the relative cost of the draft and target models. If the draft model's per-token cost is $c$ relative to one target forward pass, with $c \ll 1$, the speedup is approximately:

$$\text{Speedup} \approx \frac{E[\text{tokens per round}]}{cK + 1}$$
| Acceptance Rate (α) | K=4 Tokens/Round | K=8 Tokens/Round | Approx. Speedup (K=5, c=0.05) |
|---|---|---|---|
| 0.5 | 1.94 | 2.00 | 1.5x |
| 0.7 | 2.53 | 2.79 | 2.0x |
| 0.8 | 3.00 | 3.57 | 2.4x |
| 0.9 | 3.69 | 5.70 | 2.9x |
| 0.95 | 4.24 | 7.30 | 3.4x |
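The two formulas above translate directly into code (a sketch; the table entries involve additional rounding and assumptions, so they may not match to the last digit):

```python
def expected_tokens(alpha: float, k: int) -> float:
    """E[tokens/round] = (1 - alpha^(K+1)) / (1 - alpha): a geometric
    series over accepted prefixes, plus the guaranteed sampled token."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def speedup(alpha: float, k: int, c: float) -> float:
    """Approximate wallclock speedup when each draft token costs c
    relative to one target forward pass."""
    return expected_tokens(alpha, k) / (c * k + 1)

print(expected_tokens(0.5, 4))  # 1.9375 (~1.94 in the table)
```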
In practice, speculative decoding achieves 2-3x speedup for code generation and translation tasks (where the draft model is a good predictor), and 1.3-1.8x for more creative/open-ended generation.
Draft Model Selection
The choice of draft model significantly impacts the acceptance rate and overall speedup:
| Strategy | Draft Model | Typical α | Notes |
|---|---|---|---|
| Smaller version | LLaMA-1B for LLaMA-70B | 0.7-0.85 | Most common approach |
| Quantized target | INT4 version of target | 0.8-0.9 | High acceptance, but still expensive |
| N-gram / lookup | Token frequency table | 0.3-0.5 | Nearly free, low acceptance |
| Medusa heads | Extra prediction heads on target | 0.6-0.8 | No separate model needed |
| EAGLE | Feature-level draft | 0.7-0.85 | Predicts hidden states, not tokens |
Continuous Batching
The Static Batching Problem
In traditional (static) batching, a batch of requests is assembled, processed until the longest request finishes, and then the batch is released. Short requests that finish early waste GPU cycles while padding to match the longest request:
Static Batch:
Request A: [============]
Request B: [====] ← GPU idle for 67% of the time
Request C: [========] ← GPU idle for 33% of the time
↑ All must wait for A to finish before new requests can begin
If requests have variable output lengths (which they always do), static batching wastes 30-70% of GPU capacity on padding.
Continuous Batching (Iteration-Level Scheduling)
Continuous batching (Yu et al., 2022) operates at the granularity of individual decode steps rather than complete requests. After each decode iteration, finished requests are evicted and new requests are inserted:
Continuous Batching:
Time → T0 T1 T2 T3 T4 T5 T6 T7 T8
Slot 0: A A A A D D F F F
Slot 1: B B C C C E E G G
Slot 2: C C C D D D D G G
Requests A,B finish → slots reused by D,E immediately
No wasted GPU cycles on padding!
Continuous batching can increase throughput by 2-5x compared to static batching, depending on the variance in output lengths.
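The effect is easy to see in a toy simulation (our sketch, counting decode iterations only and ignoring prefill and scheduling cost):

```python
def static_batch_steps(lengths: list, slots: int) -> int:
    """Static batching: each batch of `slots` requests runs until its
    longest member finishes; short requests pad out the difference."""
    steps = 0
    for i in range(0, len(lengths), slots):
        steps += max(lengths[i:i + slots])
    return steps

def continuous_batch_steps(lengths: list, slots: int) -> int:
    """Iteration-level scheduling: a finished request's slot is refilled
    on the next step, so every occupied slot-step does useful work."""
    pending = list(lengths)
    active = []  # remaining tokens for each occupied slot
    steps = 0
    while pending or active:
        while pending and len(active) < slots:
            active.append(pending.pop())
        steps += 1
        active = [r - 1 for r in active if r > 1]
    return steps

lengths = [100, 10, 30, 100, 10, 30]  # highly variable output lengths
print(static_batch_steps(lengths, 3))      # 200
print(continuous_batch_steps(lengths, 3))  # substantially fewer steps
```

The gap widens as output-length variance grows, which is exactly the regime real serving traffic lives in.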
PagedAttention and vLLM
The memory management challenge in continuous batching is the KV-cache. Each active request has a KV-cache that grows with each generated token. Requests start and end at different times, creating fragmentation in GPU memory --- analogous to the memory fragmentation problem in operating systems.
PagedAttention (Kwon et al., 2023), introduced in vLLM, solves this by borrowing the concept of virtual memory paging. Instead of allocating a single contiguous block for each request's KV-cache, PagedAttention divides the cache into fixed-size blocks (e.g., 16 tokens per block) and maps them through a block table:
Request A's KV-Cache (logical): [Block 0][Block 1][Block 2][Block 3]
↓ ↓ ↓ ↓
Physical GPU memory pages: [Page 7 ][Page 2 ][Page 13][Page 4 ]
↑ ↑
Request B's KV-Cache (logical): [Block 0][Block 1][Block 2]
↓ ↓ ↓
Physical GPU memory pages: [Page 1 ][Page 9 ][Page 11]
This provides several benefits:
- No internal fragmentation: Memory is allocated in small fixed-size pages, not large contiguous blocks. Waste is at most one page per request.
- No external fragmentation: Pages can be allocated from anywhere in GPU memory, unlike contiguous allocation which suffers from fragmentation.
- Memory sharing: Requests that share a common prefix (e.g., system prompt) can share KV-cache pages via copy-on-write, dramatically reducing memory for batched serving with shared prompts.
- Near-optimal utilization: vLLM reports KV-cache memory utilization above 96%, compared to 20-50% for static allocation.
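The block-table idea can be illustrated with a toy allocator (our sketch, not vLLM's implementation):

```python
class BlockTable:
    """Toy PagedAttention-style allocator: each request's KV-cache grows
    in fixed-size pages drawn from a shared free pool and mapped through
    a per-request block table."""

    def __init__(self, num_pages: int, page_size: int = 16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.tables = {}   # request id -> list of physical page ids
        self.lengths = {}  # request id -> tokens stored

    def append_token(self, req: str):
        table = self.tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        if n % self.page_size == 0:  # current page full: map a new one
            if not self.free_pages:
                raise MemoryError("out of KV-cache pages")
            table.append(self.free_pages.pop())
        self.lengths[req] = n + 1

    def free(self, req: str):
        self.free_pages.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)

bt = BlockTable(num_pages=8, page_size=16)
for _ in range(40):                # request A: 40 tokens -> 3 pages
    bt.append_token("A")
assert len(bt.tables["A"]) == 3    # waste is at most one partial page
bt.free("A")
assert len(bt.free_pages) == 8     # pages return to the shared pool
```

Prefix sharing falls out naturally: two requests whose block tables point at the same physical pages for a shared system prompt only pay for that prefix once.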
from vllm import LLM, SamplingParams
# Initialize vLLM with PagedAttention
llm = LLM(
model="meta-llama/Llama-2-70b-chat-hf",
tensor_parallel_size=4, # Spread across 4 GPUs
max_num_seqs=256, # Max concurrent sequences
max_num_batched_tokens=8192, # Max tokens per iteration
gpu_memory_utilization=0.90, # Use 90% of GPU memory for KV-cache
quantization="awq", # Optional: combine with quantization
dtype="bfloat16",
)
# Serve multiple requests efficiently
prompts = [
"Explain the theory of relativity in simple terms.",
"Write a Python function to sort a linked list.",
"What is the capital of France?",
# ... hundreds more requests
]
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=512,
)
# vLLM handles continuous batching, KV-cache management,
# and PagedAttention automatically
outputs = llm.generate(prompts, sampling_params)
Production Serving Stack
Optimization Pipeline
A production deployment typically stacks multiple optimizations:
Trained Model (FP16/BF16)
↓
Quantization (AWQ/GPTQ to INT4) → 4x memory reduction
↓
Tensor Parallelism (across GPUs) → Distribute model across nodes
↓
Continuous Batching + PagedAttention → Maximize throughput
↓
Speculative Decoding (optional) → Reduce per-request latency
↓
Production Serving Engine (vLLM, TRT-LLM)
Inference Engine Comparison
| Engine | PagedAttention | Speculative Decoding | Quantization | Multi-GPU | Best For |
|---|---|---|---|---|---|
| vLLM | Yes (native) | Yes | AWQ, GPTQ, FP8 | TP, PP | High-throughput serving |
| TensorRT-LLM | Yes | Yes | INT4, INT8, FP8 | TP, PP | Lowest latency on NVIDIA |
| SGLang | Yes (RadixAttention) | Yes | AWQ, GPTQ | TP | Structured generation |
| llama.cpp | Partial | No | GGUF (2-8 bit) | Limited | CPU/edge deployment |
| TGI | Yes | No | AWQ, GPTQ, bitsandbytes | TP | HuggingFace ecosystem |
| Ollama | Via llama.cpp | No | GGUF | Limited | Local development |
Latency Breakdown
For a typical request with 512 input tokens and 256 output tokens on a LLaMA-70B (INT4, single A100):
| Phase | Time | Percentage | Bottleneck |
|---|---|---|---|
| Prefill (process 512 input tokens) | ~150ms | 12% | Compute-bound |
| Decode (generate 256 output tokens) | ~1000ms (~4ms/token) | 80% | Memory-bandwidth-bound |
| Sampling + post-processing | ~50ms | 4% | CPU |
| Network + scheduling overhead | ~50ms | 4% | I/O |
| Total | ~1250ms | 100% | |
Cost Optimization: Choosing the Right Configuration
The optimal serving configuration depends on whether you are optimizing for latency (time per request) or throughput (total tokens per dollar):
| Priority | Strategy | Typical Config |
|---|---|---|
| Lowest latency | Tensor parallelism across many GPUs, speculative decoding, FP8 | 4-8 GPUs per model, small batches |
| Highest throughput | Maximum batch size, INT4 quantization, continuous batching | 1-2 GPUs per model, large batches |
| Lowest cost | Aggressive quantization (INT4), maximize batch size, spot instances | Minimum GPUs, maximum utilization |
Summary
| Technique | What It Does | Speedup / Savings | Tradeoff |
|---|---|---|---|
| KV-Cache | Caches key/value projections across decode steps | O(T²) → O(T) compute | Memory proportional to sequence length |
| GQA | Reduces KV heads (see Part 3) | 4-8x KV-cache reduction | Marginal quality impact |
| INT4 Quantization | Reduces weight precision to 4 bits | 4x memory, ~2x throughput | ≤1% quality loss with GPTQ/AWQ |
| Speculative Decoding | Generates multiple tokens per target model pass | 2-3x decode speed | Requires draft model, variable speedup |
| Continuous Batching | Inserts/removes requests at each decode step | 2-5x throughput | Implementation complexity |
| PagedAttention | Paged memory management for KV-cache | >96% memory utilization | Custom CUDA kernels |
Each technique targets a different bottleneck, and they compose multiplicatively. A production stack combining INT4 quantization, continuous batching with PagedAttention, and speculative decoding can serve a 70B model at throughputs that were unimaginable just two years ago --- handling thousands of concurrent users from a single 8-GPU node.
In the next post, we will explore Part 7: Minor But Important Changes --- the seemingly small architectural decisions that collectively make a large difference: removing bias terms, tied vs. untied embeddings, parallel attention and FFN blocks, initialization schemes, and other design patterns found in modern LLMs.
References
- Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023. arXiv:2211.17192.
- Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., & Jumper, J. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318.
- Frantar, E., Ashkboos, S., Hoefler, T., & Alistarh, D. (2022). GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers. ICLR 2023. arXiv:2210.17323.
- Lin, J., Tang, J., Tang, H., Yang, S., Chen, W.-M., Wang, W.-C., ... & Han, S. (2024). AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration. MLSys 2024. arXiv:2306.00978.
- Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C. H., ... & Stoica, I. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. arXiv:2309.06180.
- Yu, G.-I., Jeong, J. S., Kim, G.-W., Kim, S., & Chun, B.-G. (2022). Orca: A Distributed Serving System for Transformer-Based Generative Models. OSDI 2022.
- Dettmers, T., Pagnoni, A., Holtzman, A., & Zettlemoyer, L. (2023). QLoRA: Efficient Finetuning of Quantized Language Models. NeurIPS 2023. arXiv:2305.14314.
- Sheng, Y., Zheng, L., Yuan, B., Li, Z., Ryabinin, M., Chen, B., ... & Stoica, I. (2023). FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU. ICML 2023. arXiv:2303.06865.
- Pope, R., Douglas, S., Chowdhery, A., et al. (2023). Efficiently Scaling Transformer Inference. MLSys 2023. arXiv:2211.05102.
- Cai, T., Li, Y., Geng, Z., Peng, H., Lee, J. D., Chen, D., & Dao, T. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. ICML 2024. arXiv:2401.10774.
- Li, Y., Cai, T., Zhang, Y., Chen, D., & Dao, T. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ICML 2024. arXiv:2401.15077.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.