L2 — Paged KV cache
2.1 Why paging
The L1 KV is one big contiguous tensor per layer of shape [max_seq_len, H_kv, D]. That works for one request, but:
- Pre-allocates max_seq_len per request, even when most stop at 50 tokens.
- Can't represent multiple requests in one tensor without padding to the longest one.
- Can't release memory until the request is fully done.
A paged pool fixes all three. KV memory is sliced into fixed-size blocks (16 tokens each). A request that has 73 tokens of history holds ceil(73/16) = 5 blocks; when it finishes, those 5 blocks return to a free list. New requests grab any free block, anywhere.
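The block arithmetic above can be checked in a few lines of plain Python. This is a minimal sketch: blocks_needed is a hypothetical helper name, and the 8-block free list is a toy size chosen for illustration, not the lesson's pool.

```python
def blocks_needed(num_tokens: int, block_size: int = 16) -> int:
    """Ceiling division: number of fixed-size blocks covering num_tokens."""
    return (num_tokens + block_size - 1) // block_size


# Toy free list of 8 block IDs (assumption: the real pool is much larger).
free_blocks = list(range(8))

# A request with 73 tokens of history takes 5 blocks from the free list.
n = blocks_needed(73)
owned = [free_blocks.pop() for _ in range(n)]

# Only 73 of the 5 * 16 = 80 reserved slots hold data; the remainder is
# the internal fragmentation that paging accepts in exchange for reuse.
wasted_slots = n * 16 - 73

# When the request finishes, its blocks go back to the free list.
free_blocks.extend(owned)
```

The waste is bounded by block_size - 1 slots per request, compared with max_seq_len minus the actual length in the contiguous layout.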
2.2 Storage choice: token-flat vs block-tensor
Two equivalent ways to lay out the pool:
| | token-flat (this lesson) | block-tensor (L3+) |
|---|---|---|
| per-layer shape | [num_slots, H_kv, D] | [num_blocks, block_size, H_kv, D] |
| indexed by | flat slot id | (block_id, within) |
| writes | direct: K[slot_mapping] = k | flatten via .view(-1, H_kv, D) |
| kernel-friendly? | requires gather to read history | kernel walks block_table directly |
Both layouts use the same number of bytes. We'll use token-flat in L2 (simpler), then switch to block-tensor in L3 when we introduce a paged-aware kernel.
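The two layouts address the same memory through different coordinates. A small torch-free sketch of the conversion follows; slot_to_block and block_to_slot are names made up for illustration and do not appear in the lesson code.

```python
BLOCK_SIZE = 16


def slot_to_block(slot: int, block_size: int = BLOCK_SIZE) -> tuple[int, int]:
    """Flat slot id -> (block_id, within), the block-tensor coordinate."""
    return divmod(slot, block_size)


def block_to_slot(block_id: int, within: int, block_size: int = BLOCK_SIZE) -> int:
    """(block_id, within) -> flat slot id, the token-flat coordinate."""
    return block_id * block_size + within


# Round trip over a toy pool of 4 blocks = 64 slots: every flat slot maps
# to exactly one (block_id, within) pair and back.
for slot in range(4 * BLOCK_SIZE):
    assert block_to_slot(*slot_to_block(slot)) == slot
```

Because the mapping is a bijection, a block-tensor pool viewed as (-1, H_kv, D) lines up row for row with the token-flat pool.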
2.3 The three classes
import torch
from dataclasses import dataclass, field


class KvPool:
    """Per-layer K, V tensors. Knows nothing about which slots belong to whom."""

    def __init__(self, num_layers, num_slots, num_kv_heads, head_dim, dtype, device):
        self.K = [torch.empty(num_slots, num_kv_heads, head_dim, dtype=dtype, device=device)
                  for _ in range(num_layers)]
        self.V = [torch.empty_like(self.K[0]) for _ in range(num_layers)]

    def set_kv(self, layer, slot_mapping, k, v):
        # In-place scatter: row i of k / v lands in slot slot_mapping[i].
        self.K[layer][slot_mapping] = k
        self.V[layer][slot_mapping] = v


class BlockAllocator:
    """Free-list of block IDs. One global allocator across all requests."""

    def __init__(self, num_slots, block_size=16):
        assert num_slots % block_size == 0
        self.block_size = block_size
        self.num_blocks = num_slots // block_size
        self.free_blocks = list(range(self.num_blocks))

    def alloc_blocks(self, n):
        if n <= 0:
            return []  # guard: free_blocks[-0:] would slice the whole list
        if n > len(self.free_blocks):
            raise RuntimeError("OOM")
        out = self.free_blocks[-n:]
        del self.free_blocks[-n:]
        return out

    def free(self, blocks):
        self.free_blocks.extend(blocks)


@dataclass
class Request:
    prompt_ids: list[int]
    blocks: list[int] = field(default_factory=list)        # owned block IDs
    slot_indices: list[int] = field(default_factory=list)  # flat slot ids derived from blocks
    cur_len: int = 0
    output_ids: list[int] = field(default_factory=list)


@dataclass
class ForwardMeta:
    positions: torch.Tensor     # [T] int64
    slot_mapping: torch.Tensor  # [T] int64: where to WRITE k, v in the pool
    kv_indices: torch.Tensor    # [T_kv] int64: where to READ history
    is_prefill: bool
2.4 The reservation helper
Before each forward call, the request must have enough slots reserved. reserve() grows req.blocks on demand and updates req.slot_indices:
def reserve(req: Request, alloc: BlockAllocator, target_len: int):
    target_blocks = (target_len + alloc.block_size - 1) // alloc.block_size
    if len(req.blocks) < target_blocks:
        diff = target_blocks - len(req.blocks)
        new_bs = alloc.alloc_blocks(diff)
        req.blocks.extend(new_bs)
        for b in new_bs:
            base = b * alloc.block_size
            req.slot_indices.extend(range(base, base + alloc.block_size))
Note the slot index convention: slot = block_id * block_size + within. This is the canonical "flat-token coordinate" we'll keep using through L3+.
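A short walk-through of how reserve() grows a request across a block boundary. This is a sketch using pared-down stand-ins so it runs without torch: MiniAllocator and MiniRequest are illustrative names, the 8-block pool is a toy size, and reserve is copied from above.

```python
class MiniAllocator:
    """Stand-in for BlockAllocator: hands out blocks from the end of the list."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))

    def alloc_blocks(self, n):
        if n <= 0:
            return []
        if n > len(self.free_blocks):
            raise RuntimeError("OOM")
        out = self.free_blocks[-n:]
        del self.free_blocks[-n:]
        return out


class MiniRequest:
    """Stand-in for Request: only the fields reserve() touches."""

    def __init__(self):
        self.blocks = []
        self.slot_indices = []


def reserve(req, alloc, target_len):
    target_blocks = (target_len + alloc.block_size - 1) // alloc.block_size
    if len(req.blocks) < target_blocks:
        new_bs = alloc.alloc_blocks(target_blocks - len(req.blocks))
        req.blocks.extend(new_bs)
        for b in new_bs:
            base = b * alloc.block_size
            req.slot_indices.extend(range(base, base + alloc.block_size))


alloc = MiniAllocator(num_blocks=8)
req = MiniRequest()

reserve(req, alloc, 73)   # prefill of 73 tokens: 5 blocks, 80 slots
blocks_after_prefill = list(req.blocks)
reserve(req, alloc, 74)   # decode step: still fits, nothing is allocated
reserve(req, alloc, 81)   # crosses a block boundary: one more block
```

Logical position 80 lands in whatever block the allocator hands out next, so req.slot_indices is ordered by token position, not by physical address.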
2.5 What changes in Qwen3Attention.forward
Only the cache write and read change. The L1 contiguous write/read:
K_cache[layer, kv_write_offset:kv_write_offset+T] = k
V_cache[layer, kv_write_offset:kv_write_offset+T] = v
K = K_cache[layer, :kv_valid_len]
V = V_cache[layer, :kv_valid_len]
becomes the L2 paged version:
pool.set_kv(self.layer_idx, meta.slot_mapping, k, v)
K = pool.K[self.layer_idx][meta.kv_indices]
V = pool.V[self.layer_idx][meta.kv_indices]
Everything else (q_norm, k_norm, RoPE, GQA repeat, SDPA, o_proj) is unchanged.
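The swap is safe because gathering through kv_indices returns the history in logical token order, wherever the slots physically sit. A torch-free toy illustrates this, with a Python list standing in for one layer's K tensor; the slot numbers and the set_k / gather_k helpers are made up for the example.

```python
# One layer's K "pool" with 12 slots; None marks never-written slots.
pool_K = [None] * 12

# Slots owned by the request in logical token order (deliberately scattered).
slot_indices = [8, 9, 10, 11, 0, 1, 2, 3]


def set_k(pool, slot_mapping, k):
    """Scatter-write: the list analogue of K[slot_mapping] = k."""
    for slot, vec in zip(slot_mapping, k):
        pool[slot] = vec


def gather_k(pool, kv_indices):
    """Gather-read: the list analogue of K[kv_indices]."""
    return [pool[slot] for slot in kv_indices]


# Prefill of 5 tokens: the write set equals the read set.
prompt_k = ["k0", "k1", "k2", "k3", "k4"]
set_k(pool_K, slot_indices[:5], prompt_k)

# One decode step at position 5: write one slot, read positions 0..5.
set_k(pool_K, [slot_indices[5]], ["k5"])
history = gather_k(pool_K, slot_indices[:6])
```

The gathered history comes back as k0 through k5 in order, which is the same sequence the L1 slice K_cache[layer, :kv_valid_len] would have produced.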
2.6 Pool sizing
KV bytes per token = 2 (K&V) × H_kv × D × layers × dtype_bytes. For Qwen3-8B bf16:
2 * 8 * 128 * 36 * 2 = 147_456 bytes = 144 KiB / token
On a 32 GiB GPU with ~14 GiB free for KV after weights:
14 * 1024 * 1024 / 144 ≈ 101_944, i.e. roughly 100k token slots possible.
L2 uses NUM_SLOTS = 8192, BLOCK_SIZE = 16 → 512 blocks, plenty for one request.
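The sizing arithmetic can be reproduced directly. The sketch below uses the Qwen3-8B figures from the text; kv_bytes_per_token is a hypothetical helper name, and the 14 GiB budget is the text's own estimate, not a measured value.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes):
    """Bytes of K and V stored per token, summed over all layers."""
    return 2 * num_kv_heads * head_dim * num_layers * dtype_bytes


KiB = 1024
GiB = 1024 ** 3

# Qwen3-8B in bf16: 36 layers, 8 KV heads, head_dim 128, 2 bytes per value.
per_token = kv_bytes_per_token(num_layers=36, num_kv_heads=8,
                               head_dim=128, dtype_bytes=2)

# Upper bound on slots if ~14 GiB is left for KV after weights.
max_slots = (14 * GiB) // per_token

# Memory taken by the L2 pool itself (NUM_SLOTS = 8192).
l2_pool_bytes = 8192 * per_token
l2_pool_gib = l2_pool_bytes / GiB
```

With these numbers the L2 pool occupies 1.125 GiB, a small fraction of the available budget.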
2.7 Smoke harness changes
The L2 driver builds ForwardMeta per call:
# prefill of N tokens
reserve(req, alloc, N)
slot = torch.tensor(req.slot_indices[:N], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.arange(N, device='cuda'),
    slot_mapping = slot,
    kv_indices   = slot,   # prefill: write set == read set
    is_prefill   = True,
)
logits = model(prompt_ids, pool, meta)

# decode step i
reserve(req, alloc, req.cur_len + 1)
pos = req.cur_len
slot_new = torch.tensor([req.slot_indices[pos]], device='cuda', dtype=torch.long)
kv_idx = torch.tensor(req.slot_indices[:pos+1], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.tensor([pos], device='cuda'),
    slot_mapping = slot_new,
    kv_indices   = kv_idx,
    is_prefill   = False,
)
2.8 Pitfalls (real bugs we hit reviewing the implementation)
- cur_len = T + 1 off-by-one. After prefill, cur_len should be T, not T+1. The token we just argmax'd has not yet had its KV written; that happens on the next decode call.
- kv_idx with extra brackets. torch.tensor([[...]]) is 2-D and breaks the gather. Want a 1-D tensor.
- Forgetting self.layer_idx. Each Qwen3Attention must stash its layer index, otherwise all 36 layers end up sharing one layer's K/V tensors and you get pure noise.
- Indexing dtype. PyTorch fancy indexing should use int64 index tensors. Using int32 silently misbehaves on CPU and raises on CUDA.
- Allocator typo self.free_block vs self.free_blocks. Free-list grows wrong; only bites in L5 with eviction.
- Class-name drift. The reference says KvPool, dataclass is Request. Earlier drafts had KVPool / Req. Pick one and grep the codebase.
- Dangling debug prints. print("HEY", ...) in qwen3.py attention forward. Strip before benchmarking.
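The cur_len off-by-one is easiest to see by logging which slot each step writes. The following torch-free sketch assumes the slots are already reserved; run_bookkeeping and the slot numbers are hypothetical, and no model is called.

```python
def run_bookkeeping(prompt_len, num_decode_steps, slot_indices):
    """Return (write_log, cur_len); write_log holds (position, slot) pairs."""
    write_log = []

    # Prefill writes KV for positions 0 .. T-1.
    for pos in range(prompt_len):
        write_log.append((pos, slot_indices[pos]))
    cur_len = prompt_len  # T, not T + 1: the sampled token has no KV yet

    # Each decode step feeds the last sampled token at position cur_len,
    # writes its KV there, and only then advances cur_len.
    for _ in range(num_decode_steps):
        pos = cur_len
        write_log.append((pos, slot_indices[pos]))
        cur_len += 1
    return write_log, cur_len


slots = list(range(32, 48))  # one reserved block, chosen arbitrarily
log, cur_len = run_bookkeeping(prompt_len=5, num_decode_steps=3, slot_indices=slots)
positions_written = [pos for pos, _ in log]
```

Setting cur_len to T + 1 after prefill would skip position T, so the first decode step would gather a slot that was never written.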
2.9 Q&A from session
student question
Why use a free-list allocator instead of a bitmap or slab allocator?
For a single GPU and ~thousands of blocks, a Python list of free IDs is fine: alloc and free cost O(n) in the number of blocks moved, and n is small. Real engines use ref-counted block tables (for L8 prefix sharing) or a bump allocator plus LRU eviction (for L7+ with overflow). The free-list is the smallest thing that exposes the right interface.
student question
Why does set_kv not return anything?
It mutates the pool tensor in place. Nothing downstream needs a return value: the next attention call reads the pool directly and gathers via kv_indices. Fewer references = easier for L9 CUDA-graph capture.
2.10 Acceptance
L2 pass criteria
scripts/l2_smoke.py emits the same 20 token IDs as L1 / HF.
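One way to make a failure of this criterion easier to debug is to report where the two token sequences first diverge. This is a sketch only: first_mismatch is a hypothetical helper, the token IDs below are placeholders, and how l2_smoke.py actually reports its output is not specified here.

```python
from typing import Optional


def first_mismatch(expected: list[int], actual: list[int]) -> Optional[int]:
    """Index of the first differing token ID, or None if the lists match."""
    for i, (e, a) in enumerate(zip(expected, actual)):
        if e != a:
            return i
    if len(expected) != len(actual):
        # One list is a strict prefix of the other.
        return min(len(expected), len(actual))
    return None


# Placeholder IDs, not real model output.
reference = [11, 22, 33, 44]
same = first_mismatch(reference, [11, 22, 33, 44])
diverged = first_mismatch(reference, [11, 22, 99, 44])
truncated = first_mismatch(reference, [11, 22])
```

A mismatch at index 0 usually points at prefill, while a later index points at decode-step bookkeeping such as cur_len or kv_indices.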