mini_sglang

L2 — Paged KV cache

Replace contiguous KV with a pool + block allocator. The data structure that makes batching possible.

2.1 Why paging

The L1 KV is one big contiguous tensor per layer of shape [max_seq_len, H_kv, D]. That works for one request, but:

A paged pool fixes all three. KV memory is sliced into fixed-size blocks (16 tokens each). A request that has 73 tokens of history holds ceil(73/16) = 5 blocks; when it finishes, those 5 blocks return to a free list. New requests grab any free block, anywhere.
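The block arithmetic above can be checked in a few lines. This is a minimal sketch rather than the lesson's allocator: `blocks_needed` and the bare `free_list` are illustrative names, and the 8-block pool size is arbitrary.

```python
BLOCK_SIZE = 16  # tokens per block, as in the text


def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Ceiling division: how many fixed-size blocks cover num_tokens."""
    return (num_tokens + block_size - 1) // block_size


# A toy free list of 8 block IDs (size chosen only for illustration).
free_list = list(range(8))

# A request with 73 tokens of history takes 5 blocks off the free list.
n = blocks_needed(73)
owned = [free_list.pop() for _ in range(n)]
free_while_running = len(free_list)

# When the request finishes, its blocks go back and any request can reuse them.
free_list.extend(owned)
owned = []
```

The request never needs its blocks to be adjacent, which is why a finished request's blocks are immediately useful to whoever asks next.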

2.2 Storage choice: token-flat vs block-tensor

Two equivalent ways to lay out the pool:

                  | token-flat (this lesson)        | block-tensor (L3+)
per-layer shape   | [num_slots, H_kv, D]            | [num_blocks, block_size, H_kv, D]
indexed by        | flat slot id                    | (block_id, within)
writes            | direct: K[slot_mapping] = k     | flatten via .view(-1, H_kv, D)
kernel-friendly?  | requires gather to read history | kernel walks block_table directly

Both layouts use the same number of bytes. We'll use token-flat in L2 (simpler), then switch to block-tensor in L3 when we introduce a paged-aware kernel.
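To see why the two layouts occupy the same memory, it helps to count elements and convert coordinates by hand. The sketch below uses only the standard library; `numel`, `to_block_coords`, and `to_flat_slot` are hypothetical helper names, and the head count and head dimension are the Qwen3-8B values from the sizing section.

```python
import math

NUM_SLOTS, BLOCK_SIZE = 8192, 16     # pool size used in this lesson
H_KV, D = 8, 128                     # KV heads and head dim (Qwen3-8B values)
NUM_BLOCKS = NUM_SLOTS // BLOCK_SIZE


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    return math.prod(shape)


token_flat_shape = (NUM_SLOTS, H_KV, D)
block_tensor_shape = (NUM_BLOCKS, BLOCK_SIZE, H_KV, D)
same_size = numel(token_flat_shape) == numel(block_tensor_shape)


def to_block_coords(slot):
    """Flat slot id -> (block_id, within), assuming row-major layout."""
    return divmod(slot, BLOCK_SIZE)


def to_flat_slot(block_id, within):
    """(block_id, within) -> flat slot id."""
    return block_id * BLOCK_SIZE + within
```

Because the mapping is a plain row-major reshape, switching layouts later changes how the pool is indexed, not how much memory it needs.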

2.3 The three classes

import torch
from dataclasses import dataclass, field

class KvPool:
    """Per-layer K, V tensors. Knows nothing about which slots belong to whom."""
    def __init__(self, num_layers, num_slots, num_kv_heads, head_dim, dtype, device):
        self.K = [torch.empty(num_slots, num_kv_heads, head_dim, dtype=dtype, device=device)
                  for _ in range(num_layers)]
        self.V = [torch.empty_like(self.K[0]) for _ in range(num_layers)]
    def set_kv(self, layer, slot_mapping, k, v):
        self.K[layer][slot_mapping] = k
        self.V[layer][slot_mapping] = v

class BlockAllocator:
    """Free-list of block IDs. One global allocator across all requests."""
    def __init__(self, num_slots, block_size=16):
        assert num_slots % block_size == 0
        self.block_size = block_size
        self.num_blocks = num_slots // block_size
        self.free_blocks = list(range(self.num_blocks))
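    def alloc_blocks(self, n):
        if n <= 0:
            return []   # guard: free_blocks[-0:] would hand out (and delete) the whole list
        if n > len(self.free_blocks):
            raise RuntimeError("OOM")
        out = self.free_blocks[-n:]   # take n block IDs off the end of the free list
        del self.free_blocks[-n:]
        return out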
    def free(self, blocks):
        self.free_blocks.extend(blocks)

@dataclass
class Request:
    prompt_ids: list[int]
    blocks:        list[int] = field(default_factory=list)   # owned block IDs
    slot_indices:  list[int] = field(default_factory=list)   # flat slot ids derived from blocks
    cur_len:       int = 0
    output_ids:    list[int] = field(default_factory=list)

@dataclass
class ForwardMeta:
    positions:    torch.Tensor   # [T] int64
    slot_mapping: torch.Tensor   # [T] int64  — where to WRITE k,v in the pool
    kv_indices:   torch.Tensor   # [T_kv] int64 — where to READ history
    is_prefill:   bool

2.4 The reservation helper

Before each forward call, the request must have enough slots reserved. reserve() grows req.blocks on demand and updates req.slot_indices:

def reserve(req: Request, alloc: BlockAllocator, target_len: int):
    target_blocks = (target_len + alloc.block_size - 1) // alloc.block_size
    if len(req.blocks) < target_blocks:
        diff   = target_blocks - len(req.blocks)
        new_bs = alloc.alloc_blocks(diff)
        req.blocks.extend(new_bs)
        for b in new_bs:
            base = b * alloc.block_size
            req.slot_indices.extend(range(base, base + alloc.block_size))

Note the slot index convention: slot = block_id * block_size + within. This is the canonical "flat-token coordinate" we'll keep using through L3+.
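A small worked example makes the convention concrete. This is a sketch under stated assumptions: `expand_slots` mirrors the inner loop of reserve(), `slot_for_position` is a hypothetical helper, and the block IDs 7 and 2 are made up to show that a request's blocks need not be adjacent or ordered.

```python
BLOCK_SIZE = 16


def expand_slots(blocks, block_size=BLOCK_SIZE):
    """Expand owned block IDs into flat slot ids, the way reserve() does."""
    slots = []
    for b in blocks:
        base = b * block_size
        slots.extend(range(base, base + block_size))
    return slots


def slot_for_position(blocks, pos, block_size=BLOCK_SIZE):
    """Flat slot that holds the token at logical position pos."""
    block_id = blocks[pos // block_size]   # which owned block
    within = pos % block_size              # offset inside that block
    return block_id * block_size + within


# Hypothetical request that was handed blocks 7 and 2, in that order.
blocks = [7, 2]
slot_indices = expand_slots(blocks)

first_slot = slot_indices[0]        # position 0 lives at the start of block 7
slot_at_pos_16 = slot_indices[16]   # position 16 is the first token of block 2
```

Logical positions stay contiguous (0, 1, 2, ...) while the slots they map to jump between blocks; slot_indices is simply that mapping precomputed.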

2.5 What changes in Qwen3Attention.forward

Only the cache access changes. The L1 contiguous write/read (four lines):

K_cache[layer, kv_write_offset:kv_write_offset+T] = k
V_cache[layer, kv_write_offset:kv_write_offset+T] = v
K = K_cache[layer, :kv_valid_len]
V = V_cache[layer, :kv_valid_len]

becomes the L2 paged version:

pool.set_kv(self.layer_idx, meta.slot_mapping, k, v)
K = pool.K[self.layer_idx][meta.kv_indices]
V = pool.V[self.layer_idx][meta.kv_indices]

Everything else (q_norm, k_norm, RoPE, GQA repeat, SDPA, o_proj) is unchanged.
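The write-then-gather data flow can be traced without tensors. In this pure-Python stand-in, a list plays the role of one layer's K tensor and strings replace the [H_kv, D] rows; `set_k`, `gather_k`, the toy block size of 2, and the slot ids are all assumptions made for illustration.

```python
# One entry per slot; with a toy block size of 2 this is 6 blocks.
NUM_SLOTS = 12
K_layer = [None] * NUM_SLOTS


def set_k(slot_mapping, k_rows):
    """Scatter: write one row per new token at the slot chosen for it."""
    for slot, row in zip(slot_mapping, k_rows):
        K_layer[slot] = row


def gather_k(kv_indices):
    """Gather: read history back in logical token order."""
    return [K_layer[slot] for slot in kv_indices]


# Prefill of 2 tokens into block 5 (slots 10 and 11).
history_slots = [10, 11]
set_k(history_slots, ["k0", "k1"])

# Decode step: the request grows into block 1 (slot 2). Only the new
# token is written, but the read covers the full history.
new_slot = [2]
set_k(new_slot, ["k2"])
K = gather_k(history_slots + new_slot)
```

The gathered K comes back in token order even though the slots are scattered, which is why the rest of the attention code can stay unchanged.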

2.6 Pool sizing

KV bytes per token = 2 (K&V) × H_kv × D × layers × dtype_bytes. For Qwen3-8B bf16:

2 * 8 * 128 * 36 * 2  =  147_456 bytes  =  144 KiB / token

On a 32 GiB GPU with ~14 GiB free for KV after weights:
14 * 1024 * 1024 / 144  ≈  102k token slots possible (14 GiB expressed in KiB, divided by 144 KiB per token).

L2 uses NUM_SLOTS = 8192, BLOCK_SIZE = 16  → 512 blocks, plenty for one request.
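The sizing arithmetic is easy to get wrong by a factor of 1024, so it is worth scripting. A minimal sketch, assuming the Qwen3-8B figures quoted above (8 KV heads, head dim 128, 36 layers, 2 bytes per bf16 value); the function and variable names are illustrative.

```python
def kv_bytes_per_token(num_kv_heads, head_dim, num_layers, dtype_bytes):
    """Bytes of K plus V stored for one token across all layers."""
    return 2 * num_kv_heads * head_dim * num_layers * dtype_bytes


KIB, MIB, GIB = 1024, 1024**2, 1024**3
BLOCK_SIZE = 16

per_token = kv_bytes_per_token(num_kv_heads=8, head_dim=128,
                               num_layers=36, dtype_bytes=2)

# How many token slots fit in a 14 GiB KV budget, rounded down to whole blocks.
budget_bytes = 14 * GIB
max_slots = budget_bytes // per_token
max_blocks = max_slots // BLOCK_SIZE

# Memory actually taken by the pool this lesson uses.
lesson_pool_bytes = 8192 * per_token
lesson_pool_mib = lesson_pool_bytes / MIB
```

At 8192 slots the L2 pool takes a little over 1 GiB, a small fraction of the 14 GiB budget.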

2.7 Smoke harness changes

The L2 driver builds ForwardMeta per call:

# prefill of N tokens
reserve(req, alloc, N)
slot = torch.tensor(req.slot_indices[:N], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.arange(N, device='cuda'),
    slot_mapping = slot,
    kv_indices   = slot,            # prefill: write set == read set
    is_prefill   = True,
)
logits = model(prompt_ids, pool, meta)

# decode step i
reserve(req, alloc, req.cur_len + 1)
pos      = req.cur_len
slot_new = torch.tensor([req.slot_indices[pos]], device='cuda', dtype=torch.long)
kv_idx   = torch.tensor(req.slot_indices[:pos+1], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.tensor([pos], device='cuda'),
    slot_mapping = slot_new,
    kv_indices   = kv_idx,
    is_prefill   = False,
)

2.8 Pitfalls (real bugs we hit reviewing the implementation)

  1. cur_len = T + 1 off-by-one. After prefill, cur_len should be T, not T+1. The token we just argmax'd has not yet had its KV written — that happens on the next decode call.
  2. kv_idx with extra brackets. torch.tensor([[...]]) is 2-D and breaks the gather. Want a 1-D tensor.
  3. Forgetting self.layer_idx. Each Qwen3Attention must stash its layer index, otherwise all 36 layers read and write the same per-layer pool tensor (index 0), overwrite each other's KV, and you get pure noise.
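  4. Indexing dtype. Keep every index tensor int64 (torch.long). Support for other integer dtypes depends on the PyTorch version and the op (torch.gather, for instance, requires int64); in our case int32 silently misbehaved on CPU and raised on CUDA.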
  5. Allocator typo self.free_block vs self.free_blocks. Free-list grows wrong; only bites in L5 with eviction.
  6. Class-name drift. The reference says KvPool, dataclass is Request. Earlier drafts had KVPool/Req. Pick one and grep the codebase.
  7. Dangling debug prints. print("HEY", ...) in qwen3.py attention forward. Strip before benchmarking.
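Pitfalls 1 and 2 both come down to index bookkeeping, which can be checked without a model. The sketch below is a hypothetical helper (`plan_calls` is not part of the lesson code) that records which slots each forward call would write and read, using plain 1-D lists where the driver would build 1-D int64 tensors.

```python
def plan_calls(prompt_len, num_decode_steps, slot_indices):
    """Return ([(write_slots, read_slots), ...], final cur_len)."""
    calls = []

    # Prefill: positions 0..T-1 are both written and read.
    prefill = slot_indices[:prompt_len]
    calls.append((prefill, prefill))
    cur_len = prompt_len   # T, not T+1: the sampled token has no KV yet

    for _ in range(num_decode_steps):
        pos = cur_len
        write = [slot_indices[pos]]      # 1-D, one slot for the new token
        read = slot_indices[:pos + 1]    # 1-D, full history including it
        calls.append((write, read))
        cur_len += 1                     # advance only after the KV is written

    return calls, cur_len


# Hypothetical request owning block 2 (slots 32..47) with block size 16.
slots = list(range(32, 48))
calls, final_len = plan_calls(prompt_len=3, num_decode_steps=2,
                              slot_indices=slots)
written = [s for write, _ in calls for s in write]
```

If cur_len were set to T + 1 after prefill, the first decode step would write slot_indices[T + 1] and leave slot_indices[T] unwritten while still reading it.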

2.9 Q&A from session

Student question: Why use a free-list allocator instead of a bitmap or slab allocator?

For a single GPU and a few thousand blocks, a Python list of free IDs is fine: allocating or freeing k blocks costs O(k) (a slice off the end of the list, or an extend), and k is small. Real engines use ref-counted block tables (for L8 prefix sharing) or bump-allocator + LRU eviction (for L7+ with overflow). The free-list is the smallest thing that exposes the right interface.

Student question: Why does set_kv not return anything?

It mutates the pool tensor in place. The model never reads the pool from a Python variable; the next attention call gathers via kv_indices. Fewer references = easier for L9 CUDA-graph capture.

2.10 Acceptance

L2 pass criteria: scripts/l2_smoke.py emits the same 20 token IDs as L1 / HF.