Tags: transformers, inference, kv-cache, quantization, speculative-decoding, deployment

Transformer Deep Dive: Part 6 - Inference Optimization

Production deployment techniques - KV-cache for avoiding redundant computation, quantization for memory efficiency, speculative decoding for faster generation, and continuous batching for throughput.


Suchinthaka W.

January 20, 2025 · 7 min read

Training a large language model is expensive, but inference, the process of generating tokens at serving time, brings its own challenges. Deployed at scale, an LLM must balance latency, throughput, memory efficiency, and cost.

The Inference Challenge

| Challenge | Description |
|-----------|-------------|
| Latency | Time to first token, time per token |
| Throughput | Requests per second |
| Memory | Model weights + KV cache |
| Cost | GPU hours per 1M tokens |

Memory vs Compute Bound

Autoregressive LLM inference is typically memory-bound during the decode phase:

  1. Generating each token requires loading the entire set of model weights from memory
  2. The arithmetic intensity (FLOPs per byte loaded) is very low
  3. Modern GPUs have much higher compute throughput than memory bandwidth

| GPU | Memory BW (TB/s) | FP16 TFLOPs | Ratio |
|-----|------------------|-------------|-------|
| A100 (80GB) | 2.0 | 312 | 156:1 |
| H100 (80GB) | 3.35 | 990 | 296:1 |

For every byte loaded from memory, the GPU can perform 150-300 FP16 operations. But generating one token touches every model weight while performing only about two operations per weight (one multiply and one add) at batch size 1, so memory bandwidth, not compute, sets the pace.
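
To make the bound concrete, here is a back-of-the-envelope sketch. The 7B-parameter FP16 model is an illustrative assumption; the bandwidth figure is the A100 number from the table above:

model_params = 7e9           # assumed 7B-parameter model
bytes_per_param = 2          # FP16 weights
memory_bw = 2.0e12           # A100 memory bandwidth in bytes/s (2.0 TB/s)

# Each decode step must stream every weight from memory at least once,
# so memory traffic alone puts a floor under per-token latency.
min_latency = model_params * bytes_per_param / memory_bw
print(f"{min_latency * 1e3:.1f} ms per token")   # ~7 ms, regardless of available FLOPs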

KV-Cache

The Redundancy Problem

In autoregressive generation, we generate one token at a time. Without caching, we'd recompute attention for all previous tokens at each step:

Step 1: Compute attention for [A]
Step 2: Compute attention for [A, B]        ← Recomputes A!
Step 3: Compute attention for [A, B, C]     ← Recomputes A, B!
Step 4: Compute attention for [A, B, C, D]  ← Recomputes A, B, C!

Complexity: $O(n^2)$ total operations for $n$ tokens.

KV-Cache Solution

Cache the Key and Value projections for all previous tokens:

Step 1: Compute K₁, V₁ for [A], cache them
Step 2: Use cached K₁,V₁, compute K₂,V₂ for [B], cache
Step 3: Use cached K₁..₂,V₁..₂, compute K₃,V₃ for [C], cache
Step 4: Use cached K₁..₃,V₁..₃, compute K₄,V₄ for [D], cache

Now we only compute Q, K, V for the new token!

Implementation

import torch

class KVCache:
    def __init__(self, max_batch_size: int, max_seq_len: int,
                 n_heads: int, head_dim: int, device: str):
        # Pre-allocate one slot per (batch, position, head); filled as decoding proceeds
        self.cache_k = torch.zeros(
            (max_batch_size, max_seq_len, n_heads, head_dim), device=device
        )
        self.cache_v = torch.zeros(
            (max_batch_size, max_seq_len, n_heads, head_dim), device=device
        )
        self.seq_len = 0  # number of positions already cached

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # k, v: (batch, new_tokens, n_heads, head_dim) for the newly generated token(s)
        batch_size, seq_len = k.shape[:2]
        self.cache_k[:batch_size, self.seq_len:self.seq_len + seq_len] = k
        self.cache_v[:batch_size, self.seq_len:self.seq_len + seq_len] = v
        self.seq_len += seq_len
        # Return everything cached so far, ready to be used in attention
        return self.cache_k[:batch_size, :self.seq_len], \
               self.cache_v[:batch_size, :self.seq_len]
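
A minimal sketch of how this cache slots into a decode loop; the random tensors below stand in for the real Q/K/V projections of a single new token:

import torch
import torch.nn.functional as F

batch, n_heads, head_dim = 1, 8, 64
cache = KVCache(max_batch_size=1, max_seq_len=2048,
                n_heads=n_heads, head_dim=head_dim, device="cpu")

for step in range(3):
    # Project only the newest token (illustrative random tensors in place of real projections)
    q = torch.randn(batch, 1, n_heads, head_dim)
    k = torch.randn(batch, 1, n_heads, head_dim)
    v = torch.randn(batch, 1, n_heads, head_dim)

    keys, values = cache.update(k, v)   # all cached K/V so far: (batch, step+1, heads, dim)

    # Attention over the full cached context, computed for just the new token
    q, keys, values = (t.transpose(1, 2) for t in (q, keys, values))
    out = F.scaled_dot_product_attention(q, keys, values)   # (batch, heads, 1, head_dim)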

KV-Cache Memory

For a model with L layers, n_heads KV heads, and head dimension d_head, the cache stores one key and one value vector per token, per layer, per head:

$$\text{KV Cache} = 2 \times L \times n_{heads} \times d_{head} \times \text{seq\_len} \times \text{batch\_size} \times \text{bytes per element}$$

| Model | 2K Context | 32K Context | 128K Context |
|-------|------------|-------------|--------------|
| LLaMA 7B | 1 GB | 16 GB | 64 GB |
| LLaMA 70B | 10 GB | 160 GB | 640 GB |

This is why GQA (Part 3) matters so much: it shrinks the KV cache in proportion to the ratio of query heads to KV heads.
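
As a sanity check on the formula, a small helper (the LLaMA 7B configuration below, 32 layers and 32 heads of dimension 128, is the published one; FP16 storage and batch size 1 are assumed):

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # Factor of 2 covers keys and values; bytes_per_elem=2 assumes FP16/BF16 storage
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# LLaMA 7B (32 layers, 32 heads, head_dim 128) at 2K context
print(kv_cache_bytes(32, 32, 128, 2048) / 1e9)   # ~1.1 GB, matching the table
# The same model with GQA using 8 KV heads instead of 32
print(kv_cache_bytes(32, 8, 128, 2048) / 1e9)    # ~0.27 GB, a 4x smaller cache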

Quantization

The Idea

Reduce the precision of the weights (and optionally the activations) to shrink the memory footprint and, since decoding is memory-bound, increase throughput.

Common Formats

| Format | Bits | Memory Reduction (vs FP32) | Quality Impact |
|--------|------|----------------------------|----------------|
| FP16 | 16 | 2× | Minimal |
| INT8 | 8 | 4× | Small |
| INT4 | 4 | 8× | Moderate |
| INT2 | 2 | 16× | Significant |

Weight-Only Quantization

Quantize weights but keep activations in higher precision:

$$W_{int} = \text{round}\left(\frac{W - \min(W)}{\max(W) - \min(W)} \times (2^b - 1)\right)$$

During inference:

$$Y = X \cdot \text{dequantize}(W_{int})$$
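
A toy per-tensor version of this min-max scheme, as a sketch rather than any particular library's kernel:

import torch

def quantize_weight(w: torch.Tensor, bits: int = 8):
    # Map the weight range [min, max] onto the integer grid [0, 2^bits - 1]
    w_min, w_max = w.min(), w.max()
    scale = (w_max - w_min) / (2**bits - 1)
    w_int = torch.round((w - w_min) / scale).clamp(0, 2**bits - 1).to(torch.uint8)
    return w_int, scale, w_min

def dequantize_weight(w_int, scale, w_min):
    return w_int.float() * scale + w_min

W = torch.randn(256, 256)
W_int, scale, w_min = quantize_weight(W, bits=8)
X = torch.randn(1, 256)
Y = X @ dequantize_weight(W_int, scale, w_min)   # Y = X · dequantize(W_int)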

Group Quantization

Quantize weights in groups (e.g., 128 elements), each with its own scale and zero point:

import torch

def quantize_group(weight, group_size=128, bits=4):
    # Reshape to (num_groups, group_size); assumes numel is divisible by group_size
    weight = weight.reshape(-1, group_size)

    # Per-group min/max statistics define an asymmetric grid over [0, 2^bits - 1]
    w_min = weight.min(dim=1, keepdim=True)[0]
    w_max = weight.max(dim=1, keepdim=True)[0]
    scales = (w_max - w_min) / (2**bits - 1)
    zeros = w_min

    # Quantize each group with its own scale and zero point
    weight_int = torch.round((weight - zeros) / scales).clamp(0, 2**bits - 1).to(torch.uint8)

    return weight_int, scales, zeros
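
A matching dequantizer (a hypothetical helper mirroring the function above) recovers an approximation of the original weights; shrinking the group size or raising the bit width shrinks the reconstruction error:

def dequantize_group(weight_int, scales, zeros, orig_shape):
    # Undo the per-group affine mapping and restore the original layout
    return (weight_int.float() * scales + zeros).reshape(orig_shape)

w = torch.randn(4096, 4096)
w_int, scales, zeros = quantize_group(w, group_size=128, bits=4)
w_hat = dequantize_group(w_int, scales, zeros, w.shape)
print((w - w_hat).abs().mean())   # small mean reconstruction error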

Popular Quantization Methods

| Method | Approach | Quality |
|--------|----------|---------|
| GPTQ | Layer-wise optimal rounding | High |
| AWQ | Activation-aware scaling | High |
| GGUF | Multiple quant formats | Variable |
| QLoRA | 4-bit base + LoRA adapters | High |

Quantization Recipe

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization with bitsandbytes
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # Normal float 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b",
    quantization_config=quantization_config,
    device_map="auto",
)
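
Once loaded, the quantized model generates exactly like a full-precision one; a brief usage sketch (the prompt is arbitrary):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b")
inputs = tokenizer("Explain the KV cache in one sentence.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))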

Speculative Decoding

The Bottleneck

Autoregressive decoding generates one token at a time. Each step requires:

  1. Load model weights from memory
  2. Compute forward pass
  3. Sample one token

The GPU is idle most of the time waiting for memory!

Speculative Decoding Idea

Use a small, fast draft model to generate multiple candidate tokens, then verify them in parallel with the large target model.

Draft Model (fast):    Generate [A, B, C, D, E] speculatively
Target Model (slow):   Verify all 5 tokens in ONE forward pass
                       Accept [A, B, C], reject [D, E]
                       Return 3 tokens instead of 1!

The Algorithm

  1. Draft: Small model generates K candidate tokens
  2. Verify: Large model computes probabilities for all K tokens in parallel
  3. Accept/Reject: Accept tokens that match target distribution
  4. Correct: Sample from adjusted distribution for first rejected position

import torch

def speculative_decode(draft_model, target_model, prompt, k=5):
    # Draft phase: generate k candidate tokens with the small model
    draft_tokens = []
    for _ in range(k):
        token = draft_model.sample_next(prompt + draft_tokens)
        draft_tokens.append(token)

    # Verify phase: one forward pass of each model over prompt + drafts yields
    # a next-token distribution at every drafted position
    target_probs = target_model.get_probs(prompt + draft_tokens)
    draft_probs = draft_model.get_probs(prompt + draft_tokens)

    # Accept/reject: keep draft token i with probability min(1, p_target / p_draft)
    accepted = []
    for i, token in enumerate(draft_tokens):
        if torch.rand(()) < target_probs[i, token] / draft_probs[i, token]:
            accepted.append(token)
        else:
            # Rejection: resample from the residual distribution
            # max(0, p_target - p_draft), renormalized, then stop
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            break
    # (A full implementation also samples one "bonus" token from the target
    # distribution when every draft token is accepted.)

    return accepted

Speedup

Expected tokens per step:

$$E[\text{tokens per step}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

where $\alpha$ is the acceptance rate and $K$ is the draft length.

| Acceptance Rate | K=4 | K=8 |
|-----------------|-----|-----|
| 70% | 2.5× | 3.0× |
| 80% | 3.0× | 4.0× |
| 90% | 3.7× | 5.7× |
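
A one-line helper to evaluate the formula (the function name is made up); real wall-clock speedups, like the figures in the table, land somewhat below the raw tokens-per-step value because the draft model's own forward passes are not free:

def expected_tokens_per_step(alpha: float, k: int) -> float:
    # E[tokens per step] = (1 - alpha^(K+1)) / (1 - alpha)
    return (1 - alpha ** (k + 1)) / (1 - alpha)

print(expected_tokens_per_step(0.8, 4))   # ~3.4 tokens per target-model forward pass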

Continuous Batching

Static Batching Problem

Traditional static batching waits for a batch to fill and processes all requests together:

Request 1: [============================]
Request 2: [==========]
Request 3: [==================]

With static batching, all must wait for longest:
[============================] <- Wasted GPU time

Continuous Batching Solution

Insert new requests as soon as others complete:

Time  →   T0   T1   T2   T3   T4   T5   T6   T7
Slot 0:   R1   R1   R1   R1   R4   R4   R4   R5
Slot 1:   R2   R2   R3   R3   R3   R5   R5   R5
Slot 2:   R3   R3   R3   R4   R4   R4   R6   R6
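
A toy version of the scheduling loop; heavily simplified, with hypothetical request objects and a hypothetical decode_step callback (real engines also interleave prefill with decode):

from collections import deque

def continuous_batching(waiting: deque, num_slots: int, decode_step):
    slots = [None] * num_slots            # in-flight requests; None means the slot is free
    while waiting or any(r is not None for r in slots):
        # Admit queued requests the moment a slot frees up
        for i in range(num_slots):
            if slots[i] is None and waiting:
                slots[i] = waiting.popleft()
        # One decode step advances every active request by one token
        active = [r for r in slots if r is not None]
        decode_step(active)
        # Retire finished requests immediately instead of waiting for the whole batch
        for i, r in enumerate(slots):
            if r is not None and r.finished:
                slots[i] = None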

PagedAttention (vLLM)

Manage KV cache like virtual memory with pages:

Logical KV Cache:  [Block 0][Block 1][Block 2][Block 3]
                      ↓        ↓        ↓        ↓
Physical Memory:   [Page 7 ][Page 2 ][Page 9 ][Page 4 ]

Benefits:

  • No memory fragmentation
  • Efficient memory sharing for beam search
  • Near-zero waste
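
The bookkeeping behind this can be sketched as a block table that maps each sequence's logical blocks to whatever physical pages are free (a toy illustration, not vLLM's actual implementation):

class BlockTable:
    def __init__(self, num_physical_pages: int, block_size: int = 16):
        self.block_size = block_size                  # tokens stored per page
        self.free_pages = list(range(num_physical_pages))
        self.pages = {}                               # seq_id -> list of physical page ids
        self.lengths = {}                             # seq_id -> tokens written so far

    def append_token(self, seq_id: int):
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                  # current page is full (or first token)
            self.pages.setdefault(seq_id, []).append(self.free_pages.pop())
        self.lengths[seq_id] = n + 1

    def release(self, seq_id: int):
        # A finished sequence hands its pages straight back to the free pool
        self.free_pages.extend(self.pages.pop(seq_id, []))
        self.lengths.pop(seq_id, None)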

In practice, engines like vLLM implement continuous batching and PagedAttention for you:

# Using vLLM for high-throughput serving
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-70b", tensor_parallel_size=4)

prompts = ["Explain continuous batching in one sentence."]  # any list of prompt strings
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.outputs[0].text)

Production Serving Stack

Complete Optimization Pipeline

Model → Quantization → Tensor Parallelism → Continuous Batching
  ↓
KV-Cache + PagedAttention
  ↓
Speculative Decoding (optional)
  ↓
High-throughput serving

Inference Engines

| Engine | Key Features |
|--------|--------------|
| vLLM | PagedAttention, continuous batching |
| TensorRT-LLM | NVIDIA optimization, fastest on NVIDIA |
| llama.cpp | CPU/GPU, quantization focus |
| TGI | HuggingFace, easy deployment |

Latency Breakdown

For a typical request (512 input, 256 output tokens):

| Phase | Time | % |
|-------|------|---|
| Prefill (process input) | 200ms | 15% |
| Decode (generate output) | 1000ms | 75% |
| Network/overhead | 100ms | 10% |

Summary

| Technique | Benefit | Tradeoff |
|-----------|---------|----------|
| KV-Cache | Avoid recomputation | Memory usage |
| Quantization | 4-8× memory reduction | Small quality loss |
| Speculative Decoding | 2-4× decode speed | Draft model overhead |
| Continuous Batching | Higher throughput | Implementation complexity |
| PagedAttention | Efficient memory | Kernel overhead |


In the next post, we'll explore Part 7: Minor But Important Changes - bias removal, tied embeddings, parallel attention+FFN, and initialization schemes.
