Transformer Deep Dive: Part 6 - Inference Optimization
Production deployment techniques - KV-cache for avoiding redundant computation, quantization for memory efficiency, speculative decoding for faster generation, and continuous batching for throughput.
Suchinthaka W.
January 20, 2025 · 7 min read
Training a large language model is expensive, but inference—the process of generating predictions—presents its own unique challenges. When deployed at scale, LLMs must handle latency, throughput, memory efficiency, and cost.
The Inference Challenge
| Challenge | Description |
|-----------|-------------|
| Latency | Time to first token, time per token |
| Throughput | Requests per second |
| Memory | Model weights + KV cache |
| Cost | GPU hours per 1M tokens |
Memory vs Compute Bound
Autoregressive LLM inference is typically memory-bound during the decode phase:
- Each token generation requires loading the entire model weights
- The arithmetic intensity (FLOPs per byte loaded) is very low
- Modern GPUs have much higher compute throughput than memory bandwidth
| GPU | Memory BW (TB/s) | FP16 TFLOPs | Ratio (FLOPs : bytes) |
|-----|------------------|-------------|-----------------------|
| A100 (80GB) | 2.0 | 312 | 156:1 |
| H100 (80GB) | 3.35 | 990 | 296:1 |
For every byte loaded from memory, the GPU can perform 150-300 FP16 operations. But generating one token requires loading all model weights while performing relatively few operations per weight.
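A rough back-of-the-envelope estimate makes this concrete. The figures below (a 7B-parameter model in FP16, A100-class bandwidth) are illustrative assumptions, not measurements:

# Upper bound on single-stream decode speed when memory-bandwidth-bound:
# every generated token must stream the full set of weights from HBM.
params = 7e9               # assumed 7B-parameter model
bytes_per_param = 2        # FP16
hbm_bandwidth = 2.0e12     # ~2 TB/s (A100-class), in bytes per second

weight_bytes = params * bytes_per_param    # ~14 GB read per generated token
ceiling = hbm_bandwidth / weight_bytes     # ~143 tokens/s, no matter how many FLOPs the GPU has
print(f"Bandwidth-bound decode ceiling: {ceiling:.0f} tokens/s")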
KV-Cache
The Redundancy Problem
In autoregressive generation, we generate one token at a time. Without caching, we'd recompute attention for all previous tokens at each step:
Step 1: Compute attention for [A]
Step 2: Compute attention for [A, B] ← Recomputes A!
Step 3: Compute attention for [A, B, C] ← Recomputes A, B!
Step 4: Compute attention for [A, B, C, D] ← Recomputes A, B, C!
Complexity: 1 + 2 + ⋯ + n = O(n²) total operations for n tokens.
KV-Cache Solution
Cache the Key and Value projections for all previous tokens:
Step 1: Compute K₁, V₁ for [A], cache them
Step 2: Use cached K₁,V₁, compute K₂,V₂ for [B], cache
Step 3: Use cached K₁..₂,V₁..₂, compute K₃,V₃ for [C], cache
Step 4: Use cached K₁..₃,V₁..₃, compute K₄,V₄ for [D], cache
Now we only compute Q, K, V for the new token!
Implementation
import torch


class KVCache:
    def __init__(self, max_batch_size: int, max_seq_len: int,
                 n_heads: int, head_dim: int, device: str):
        # Pre-allocate key/value buffers for the full context window
        self.cache_k = torch.zeros(
            (max_batch_size, max_seq_len, n_heads, head_dim), device=device
        )
        self.cache_v = torch.zeros(
            (max_batch_size, max_seq_len, n_heads, head_dim), device=device
        )
        self.seq_len = 0

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # Append the new tokens' K/V and return views over the full cached context
        batch_size, seq_len = k.shape[:2]
        self.cache_k[:batch_size, self.seq_len:self.seq_len + seq_len] = k
        self.cache_v[:batch_size, self.seq_len:self.seq_len + seq_len] = v
        self.seq_len += seq_len
        return (self.cache_k[:batch_size, :self.seq_len],
                self.cache_v[:batch_size, :self.seq_len])
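A sketch of how the cache slots into a single decode step; the random tensors below stand in for the real Q/K/V projections:

import torch
import torch.nn.functional as F

batch, n_heads, head_dim = 2, 8, 64
cache = KVCache(max_batch_size=4, max_seq_len=2048,
                n_heads=n_heads, head_dim=head_dim, device="cpu")

# One new token per sequence (seq_len == 1 during decode)
q = torch.randn(batch, 1, n_heads, head_dim)
k = torch.randn(batch, 1, n_heads, head_dim)
v = torch.randn(batch, 1, n_heads, head_dim)

keys, values = cache.update(k, v)   # (batch, tokens_so_far, n_heads, head_dim)
# Attend over the whole cached context, but only for the new token's query
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
)   # (batch, n_heads, 1, head_dim)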
KV-Cache Memory
For a model with L layers, n KV heads, and head dimension d, the cache must hold keys and values for every token:

KV cache size = 2 × L × n × d × seq_len × batch_size × bytes_per_element

Assuming FP16 (2 bytes per element):
| Model | 2K Context | 32K Context | 128K Context |
|-------|------------|-------------|--------------|
| LLaMA 7B | 1 GB | 16 GB | 64 GB |
| LLaMA 70B | 10 GB | 160 GB | 640 GB |
This is why GQA (Part 3) matters so much for inference: it shrinks the KV cache by the ratio of query heads to KV heads.
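A quick sanity check of the 7B row (LLaMA-7B has 32 layers and 32 heads of dimension 128; an FP16 cache is assumed):

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor of 2 covers keys and values; FP16 means 2 bytes per element
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# LLaMA-7B: 32 layers, 32 heads, head_dim 128, 2K context
print(kv_cache_bytes(32, 32, 128, 2048) / 1e9)   # ~1.07 GB, matching the table
# With GQA (e.g. 8 KV heads instead of 32), the cache shrinks 4x
print(kv_cache_bytes(32, 8, 128, 2048) / 1e9)    # ~0.27 GB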
Quantization
The Idea
Reduce precision of weights (and activations) to decrease memory and increase throughput.
Common Formats
| Format | Bits | Memory Reduction (vs FP32) | Quality Impact |
|--------|------|----------------------------|----------------|
| FP16 | 16 | 2× | Minimal |
| INT8 | 8 | 4× | Small |
| INT4 | 4 | 8× | Moderate |
| INT2 | 2 | 16× | Significant |
Weight-Only Quantization
Quantize the weights to low precision (e.g. INT4) while keeping activations in FP16/BF16. During inference, each weight group is dequantized on the fly just before the matrix multiply:

Y = X · dequant(W_int),  where  dequant(W_int) = W_int × scale + zero
Group Quantization
Quantize weights in groups (e.g., 128 elements) with per-group scale/zero:
import torch

def quantize_group(weight, group_size=128, bits=4):
    # Reshape so each row is one quantization group
    weight = weight.reshape(-1, group_size)
    # Per-group asymmetric range: map [min, max] onto integer codes [0, 2^bits - 1]
    w_min = weight.min(dim=1, keepdim=True)[0]
    w_max = weight.max(dim=1, keepdim=True)[0]
    scales = (w_max - w_min).clamp(min=1e-8) / (2**bits - 1)
    zeros = w_min
    # Quantize and clamp to the representable integer range
    weight_int = torch.round((weight - zeros) / scales)
    weight_int = weight_int.clamp(0, 2**bits - 1).to(torch.uint8)
    return weight_int, scales, zeros
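For completeness, a matching dequantization sketch under the same grouping scheme:

def dequantize_group(weight_int, scales, zeros, original_shape):
    # Invert the affine mapping w ≈ q × scale + zero and restore the original layout
    weight = weight_int.to(scales.dtype) * scales + zeros
    return weight.reshape(original_shape)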
Popular Quantization Methods
| Method | Approach | Quality |
|--------|----------|---------|
| GPTQ | Layer-wise optimal rounding | High |
| AWQ | Activation-aware scaling | High |
| GGUF | Multiple quant formats | Variable |
| QLoRA | 4-bit base + LoRA adapters | High |
Quantization Recipe
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization with bitsandbytes
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for the matmuls
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b",
    quantization_config=quantization_config,
    device_map="auto",
)
Speculative Decoding
The Bottleneck
Autoregressive decoding generates one token at a time. Each step requires:
- Load model weights from memory
- Compute forward pass
- Sample one token
The GPU is idle most of the time waiting for memory!
Speculative Decoding Idea
Use a small, fast draft model to generate multiple candidate tokens, then verify them in parallel with the large target model.
Draft Model (fast): Generate [A, B, C, D, E] speculatively
Target Model (slow): Verify all 5 tokens in ONE forward pass
Accept [A, B, C], reject [D, E]
Return 3 tokens instead of 1!
The Algorithm
- Draft: Small model generates K candidate tokens
- Verify: Large model computes probabilities for all K tokens in parallel
- Accept/Reject: Accept tokens that match target distribution
- Correct: Sample from adjusted distribution for first rejected position
import random

# Simplified pseudocode: draft_model/target_model expose a toy interface, and
# `sample` draws a token id from a probability vector (helper not shown).
def speculative_decode(draft_model, target_model, prompt, k=5):
    # Draft phase: the small model proposes k candidate tokens
    draft_tokens = []
    for _ in range(k):
        token = draft_model.sample_next(prompt + draft_tokens)
        draft_tokens.append(token)

    # Verify phase: one forward pass of each model over prompt + candidates
    target_probs = target_model.get_probs(prompt + draft_tokens)
    draft_probs = draft_model.get_probs(prompt + draft_tokens)

    # Accept each candidate with probability min(1, p_target / p_draft)
    accepted = []
    for i, token in enumerate(draft_tokens):
        if random.random() < min(1.0, target_probs[i, token] / draft_probs[i, token]):
            accepted.append(token)
        else:
            # On rejection, resample from the normalized residual distribution
            # max(0, p_target − p_draft); this keeps the output distribution
            # identical to sampling from the target model alone.
            adjusted = (target_probs[i] - draft_probs[i]).clamp(min=0)
            accepted.append(sample(adjusted / adjusted.sum()))
            break
    # (In the full algorithm, if all k candidates are accepted, one extra
    # "bonus" token is sampled from the target distribution.)
    return accepted
Speedup
Expected tokens accepted per target-model forward pass:

E[tokens] = (1 − α^(K+1)) / (1 − α)

where α is the per-token acceptance rate and K is the draft length.
| Acceptance Rate α | K=4 | K=8 |
|-------------------|-----|-----|
| 70% | 2.5× | 3.0× |
| 80% | 3.0× | 4.0× |
| 90% | 3.7× | 5.7× |
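The raw formula is easy to check numerically; note that end-to-end speedup is somewhat lower than the expected-token count once the draft model's own forward passes are accounted for:

def expected_tokens(alpha: float, k: int) -> float:
    # Expected number of tokens accepted per target-model forward pass
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.7, 0.8, 0.9):
    print(f"α={alpha}: K=4 → {expected_tokens(alpha, 4):.2f}, K=8 → {expected_tokens(alpha, 8):.2f}")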
Continuous Batching
Static Batching Problem
Traditional batching waits for a batch to fill, then processes all requests together until the longest one finishes:
Request 1: [============================]
Request 2: [==========]
Request 3: [==================]
With static batching, all must wait for longest:
[============================] <- Wasted GPU time
Continuous Batching Solution
Insert new requests as soon as others complete:
Time → T0 T1 T2 T3 T4 T5 T6 T7
Slot 0: R1 R1 R1 R1 R4 R4 R4 R5
Slot 1: R2 R2 R3 R3 R3 R5 R5 R5
Slot 2: R3 R3 R3 R4 R4 R4 R6 R6
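A toy sketch of this scheduling loop; the `Request` class and `decode_step` stub below are illustrative stand-ins, not vLLM's actual scheduler:

from collections import deque
from dataclasses import dataclass, field

@dataclass
class Request:
    prompt: str
    max_new_tokens: int
    tokens: list = field(default_factory=list)

    def is_finished(self) -> bool:
        return len(self.tokens) >= self.max_new_tokens

def decode_step(active):
    # Stand-in for one batched forward pass: append one token to each request
    for req in active:
        req.tokens.append("<tok>")

def continuous_batching_step(slots, pending: deque):
    # Release slots whose requests finished, then refill them immediately
    for i, req in enumerate(slots):
        if req is not None and req.is_finished():
            slots[i] = None
        if slots[i] is None and pending:
            slots[i] = pending.popleft()
    # One batched decode step over every currently active request
    active = [r for r in slots if r is not None]
    if active:
        decode_step(active)
    return slots

# Toy driver: 3 slots, requests of different lengths arrive over time
pending = deque(Request(f"r{i}", max_new_tokens=n)
                for i, n in enumerate([28, 10, 18, 12, 8, 6]))
slots = [None, None, None]
while pending or any(slots):
    slots = continuous_batching_step(slots, pending)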
PagedAttention (vLLM)
Manage KV cache like virtual memory with pages:
Logical KV Cache: [Block 0][Block 1][Block 2][Block 3]
↓ ↓ ↓ ↓
Physical Memory: [Page 7 ][Page 2 ][Page 9 ][Page 4 ]
Benefits:
- No memory fragmentation
- Efficient memory sharing for beam search
- Near-zero waste
# Using vLLM for high-throughput serving
from vllm import LLM, SamplingParams

prompts = ["Explain the KV cache in one paragraph."]   # example inputs

llm = LLM(model="meta-llama/Llama-2-70b", tensor_parallel_size=4)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
Production Serving Stack
Complete Optimization Pipeline
Model → Quantization → Tensor Parallelism → Continuous Batching
↓
KV-Cache + PagedAttention
↓
Speculative Decoding (optional)
↓
High-throughput serving
Inference Engines
| Engine | Key Features |
|--------|--------------|
| vLLM | PagedAttention, continuous batching |
| TensorRT-LLM | Heavily optimized for NVIDIA GPUs |
| llama.cpp | CPU/GPU, quantization focus |
| TGI | HuggingFace, easy deployment |
Latency Breakdown
For a typical request (512 input, 256 output tokens):
| Phase | Time | % |
|-------|------|---|
| Prefill (process input) | 200 ms | 15% |
| Decode (generate output) | 1000 ms | 75% |
| Network/overhead | 100 ms | 10% |
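Decode dominates: 1000 ms for 256 output tokens is about 4 ms per token, while prefill handles all 512 input tokens in a single 200 ms pass. This is why the decode-side techniques above (KV-cache, quantization, speculative decoding) have the biggest impact on end-to-end latency.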
Summary
| Technique | Benefit | Tradeoff |
|-----------|---------|----------|
| KV-Cache | Avoid recomputation | Memory usage |
| Quantization | 4-8× memory reduction | Small quality loss |
| Speculative Decoding | 2-4× decode speed | Draft model overhead |
| Continuous Batching | Higher throughput | Implementation complexity |
| PagedAttention | Efficient memory | Kernel overhead |
In the next post, we'll explore Part 7: Minor But Important Changes - bias removal, tied embeddings, parallel attention+FFN, and initialization schemes.