Introduction
Autoregressive language models generate text one token at a time. Each new prediction requires a full forward pass through all transformer layers, leading to redundant computations.
For example, generating the next token in:
["What", "is", "in"] → "the"
requires recomputing attention over ["What", "is", "in"], even though those tokens haven't changed.
KV Caching solves this inefficiency by storing and reusing intermediate computations. In this post, we’ll:
- Revisit transformer attention mechanics.
- Identify where redundancy occurs.
- Implement KV Caching in nanoVLM (a minimal VLM built with PyTorch).
- Benchmark the speedup (≈38% faster generation).
Revisiting Transformer Attention
A transformer layer consists of:
- Multi-head self-attention
- Feed-forward network (MLP)
- Residual connections & layer norm
Self-attention computes:
- Queries (Q), Keys (K), and Values (V) from the input embeddings.
- Attention scores via softmax(QKᵀ / √dₖ).
- The output as a weighted sum of V.
Here’s a minimal PyTorch implementation:
```python
import torch
import torch.nn.functional as F

# Input embeddings (seq_len=5, dim=10)
input_emb = torch.randn(5, 10)

# Projection matrices (random placeholders for illustration)
W_q = torch.randn(10, 10)
W_k = torch.randn(10, 10)
W_v = torch.randn(10, 10)

# Project to Q, K, V
Q = input_emb @ W_q
K = input_emb @ W_k
V = input_emb @ W_v

# Causal masking (no peeking ahead), with the usual 1/sqrt(d_k) scaling
mask = torch.tril(torch.ones(5, 5))
scores = (Q @ K.T / K.shape[-1] ** 0.5).masked_fill(mask == 0, -torch.inf)
output = F.softmax(scores, dim=-1) @ V
```
Where Redundancy Creeps In
During autoregressive generation:
- The model predicts tᵢ₊₁ given [t₀...tᵢ].
- At each step, it recomputes K and V for the entire sequence, even though only the newest token changes.
Example:
```python
new_token_emb = torch.randn(1, 10)
extended_input = torch.cat([input_emb, new_token_emb], dim=0)

# K and V for the first 5 tokens are unchanged!
K_ext = extended_input @ W_k
torch.testing.assert_close(K, K_ext[:5])  # Passes!
```
```
Original (5×5):          Extended (6×6):
■ ■ ■ ■ ■                ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■                ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■       →        ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■                ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■                ■ ■ ■ ■ ■ □
                         □ □ □ □ □ □
```
■ = already computed in the previous step (reusable; recomputing it is wasted work)
□ = new entries introduced by the latest token
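In fact, with causal masking the first five rows of the extended attention output are numerically identical to the original ones (up to floating-point tolerance), so recomputing the ■ region buys nothing. A quick check, continuing from the snippets above:

```python
# Recompute attention over the full extended sequence (the "no cache" path)
Q_ext = extended_input @ W_q
V_ext = extended_input @ W_v

mask_ext = torch.tril(torch.ones(6, 6))
scores_ext = (Q_ext @ K_ext.T / K_ext.shape[-1] ** 0.5).masked_fill(mask_ext == 0, -torch.inf)
output_ext = F.softmax(scores_ext, dim=-1) @ V_ext

# Rows 0-4 match the original 5-token output: only the last row is genuinely new
torch.testing.assert_close(output, output_ext[:5])  # Passes!
```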
How KV Caching Fixes It
Instead of recomputing K and V for the entire sequence:
- Cache K and V after the first pass.
- For each new token, compute only its new K and V.
- Concatenate them with the cached values.
Key Insight:

- Prefill Phase: Process the full prompt and populate the cache.
- Decode Phase: Generate tokens incrementally using the cached K/V (a toy end-to-end sketch of both phases follows below).
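Here is a toy, single-head sketch of the two phases, with a plain dictionary standing in for the cache. The projection matrices, the attend helper, and all dimensions are made up for illustration; this is not nanoVLM's implementation, just the mechanic in miniature:

```python
import torch
import torch.nn.functional as F

dim = 10
W_q, W_k, W_v = torch.randn(dim, dim), torch.randn(dim, dim), torch.randn(dim, dim)

def attend(q, K, V):
    # q: (1, dim) query of the newest token; K, V: (seq_len, dim) cached keys/values
    scores = q @ K.T / dim ** 0.5
    return F.softmax(scores, dim=-1) @ V

# Prefill phase: run the full prompt once and populate the cache
prompt_emb = torch.randn(5, dim)
cache = {"key": prompt_emb @ W_k, "value": prompt_emb @ W_v}

# Decode phase: each step projects only the newest token and extends the cache
for _ in range(3):
    new_emb = torch.randn(1, dim)  # stand-in for the next token's embedding
    cache["key"] = torch.cat([cache["key"], new_emb @ W_k], dim=0)
    cache["value"] = torch.cat([cache["value"], new_emb @ W_v], dim=0)
    out = attend(new_emb @ W_q, cache["key"], cache["value"])

print(cache["key"].shape)  # torch.Size([8, 10]): 5 prompt tokens + 3 decoded tokens
```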
Implementing KV Cache in nanoVLM
We modified three components:
1. Attention Block (LanguageModelGroupedAttention)
```python
def forward(self, x, cos, sin, block_kv_cache=None):
    is_prefill = block_kv_cache is None

    # Project only the tokens passed in this call to Q, K, V (simplified)
    q_curr, k_curr, v_curr = project_current_tokens(x)

    if not is_prefill:
        # Decode: append the new K/V to the cached ones
        k = torch.cat([block_kv_cache['key'], k_curr], dim=2)
        v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
    else:
        # Prefill: the current K/V become the initial cache
        k, v = k_curr, v_curr

    block_kv_cache = {'key': k, 'value': v}
    return attention_output, block_kv_cache
```
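The concatenation runs along dim=2, which is the sequence axis assuming the common (batch, kv_heads, seq_len, head_dim) layout for grouped-query attention. A quick shape check with made-up dimensions shows the cache growing by one position per decode step:

```python
import torch

# Hypothetical shapes for illustration (not nanoVLM's actual configuration)
batch, n_kv_heads, head_dim = 1, 4, 16

k_cache = torch.randn(batch, n_kv_heads, 12, head_dim)  # 12 tokens already cached
k_curr = torch.randn(batch, n_kv_heads, 1, head_dim)    # K projection of the one new token

k = torch.cat([k_cache, k_curr], dim=2)                  # dim=2 is the sequence axis
print(k.shape)                                           # torch.Size([1, 4, 13, 16])
```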
2. Layer-Wise Cache Tracking (LanguageModel)
```python
def forward(self, x, kv_cache=None, start_pos=0):
    # One cache slot per transformer block
    if kv_cache is None:
        kv_cache = [None] * len(self.blocks)

    # cos/sin are the rotary embeddings for positions start_pos .. start_pos + seq_len
    for i, block in enumerate(self.blocks):
        x, kv_cache[i] = block(x, cos, sin, kv_cache[i])

    return x, kv_cache
```
3. Two-Phase Generation (VisionLanguageModel)
```python
# Prefill: process the prompt once and build the cache
prompt_out, kv_cache = self.forward(prompt, kv_cache=None, start_pos=0)

# Decode: generate tokens incrementally, one token per forward pass
out = prompt_out
for i in range(max_new_tokens):
    next_token = sample(out)
    # The new token sits at absolute position prompt_len + i (prompt_len = prompt length in tokens)
    out, kv_cache = self.forward(next_token, kv_cache, start_pos=prompt_len + i)
```
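The start_pos argument is what keeps rotary position embeddings correct: during decode only one token is fed in, so its cos/sin must be computed at its absolute position rather than at position 0. Here is a minimal sketch of that idea using standard RoPE frequencies; it is not nanoVLM's exact code:

```python
import torch

def rotary_cos_sin(start_pos: int, seq_len: int, head_dim: int, base: float = 10000.0):
    # Per-channel-pair frequencies, as in standard RoPE
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Absolute positions of the tokens in this forward pass
    positions = torch.arange(start_pos, start_pos + seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()

# Prefill: positions 0..4 for a 5-token prompt
cos, sin = rotary_cos_sin(start_pos=0, seq_len=5, head_dim=16)
# Decode step: a single token at absolute position 5
cos, sin = rotary_cos_sin(start_pos=5, seq_len=1, head_dim=16)
```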
Results & Takeaways
- ✅ ≈38% faster generation (benchmarked on nanoVLM).
- ✅ Memory-efficient: the cache grows linearly with sequence length (a rough sizing sketch follows below).
- ✅ Position-aware: rotary embeddings stay correct via start_pos.
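To make "grows linearly" concrete, here is a back-of-envelope estimate of cache size. The dimensions are illustrative, not necessarily nanoVLM's actual configuration:

```python
# Rough KV-cache size estimate (illustrative numbers)
n_layers, n_kv_heads, head_dim = 30, 3, 64
bytes_per_elem = 2                      # fp16 / bf16
seq_len = 1024

# One K and one V tensor per layer, each of shape (n_kv_heads, seq_len, head_dim)
cache_bytes = 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_elem
print(f"{cache_bytes / 1e6:.1f} MB")    # ~23.6 MB at 1024 tokens; doubles at 2048
```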
Trade-offs:
- Slightly more complex code.
- Restricts some advanced inference methods (e.g., beam search).
Conclusion
KV Caching is a game-changer for autoregressive models. By eliminating redundant computations, it enables faster, longer, and more efficient generation—critical for real-world applications.
Let us know your thoughts in the comments!