L8 — Radix prefix cache
8.1 The win, quantified
Chat workloads have the same system prompt, say 500 tokens, at the start of every request. Without prefix caching, every request re-prefills those 500 tokens. With prefix caching they're prefilled once; subsequent requests bump the refcount on the cached KV blocks and start from the cached KV.
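Verified on Qwen3-8B with a 97-token shared prefix + 5 unmatched tail tokens (a 102-token prompt):

```text
WITH cache: ext run = 25 q tokens in 20 steps ← only 6 unmatched tail tokens prefilled
                                                + 19 decodes × 1 token = 25 total
BASELINE: ext run = 121 q tokens in 20 steps ← full 102-token re-prefill + 19 decodes
✓ cache saved 96 q tokens (79.3% of baseline prefill compute)
✓ cached output bit-identical to baseline
```

Why 6 prefilled tokens and not 5: the cache only stores block-aligned prefixes (§8.6), so the match stops at the last block boundary inside the shared prefix (96 tokens), and the 97th prefix token is prefilled together with the 5 tail tokens. The 79.3% is 96 / 121, i.e. relative to all baseline q tokens including the 19 decodes; counting prefill alone, 96 of 102 tokens (about 94%) are skipped.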
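The token accounting above can be reproduced with a few lines of arithmetic. This is a sketch, not engine code: aligned and q_tokens are hypothetical helpers, and BLOCK_SIZE = 16 is an assumption, since the measurement does not state the block size (any block size greater than 1 that divides 96 yields the same 96-token match).

```python
def aligned(n: int, block_size: int) -> int:
    """Round n down to a block boundary, as the cache does (§8.6)."""
    return block_size * (n // block_size)


def q_tokens(prompt_len: int, matched_len: int, n_decodes: int) -> int:
    """Query tokens processed for one request: the unmatched prompt suffix
    is prefilled once, then every decode step feeds exactly 1 token."""
    return (prompt_len - matched_len) + n_decodes


BLOCK_SIZE = 16                     # assumption: not stated in the measurement
prompt_len = 97 + 5                 # shared prefix + unmatched tail = 102
matched = aligned(97, BLOCK_SIZE)   # 96: the 97th prefix token sits in a partial block
n_decodes = 19                      # 20 steps = 1 prefill step + 19 decode steps

cached = q_tokens(prompt_len, matched, n_decodes)   # 6 + 19
baseline = q_tokens(prompt_len, 0, n_decodes)       # 102 + 19
saved = baseline - cached

print(f"cached={cached} baseline={baseline} saved={saved}")
print(f"vs all baseline q tokens: {saved / baseline:.1%}")
print(f"vs baseline prefill only: {matched / prompt_len:.1%}")
```

The two percentages differ only in the denominator: one includes the decode tokens that caching cannot save, the other counts prefill alone.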
8.2 The data structure: radix tree of token sequences
Compressed trie keyed by token IDs. Each node holds a run of tokens, the KV blocks for those tokens, a dict of children keyed by first token, a parent pointer, and an LRU timestamp (eviction in §8.9 additionally indexes nodes by a level field).
Three operations: match, insert, evict.
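A minimal sketch of that node record, assuming a Python dataclass with the field names used later in this section (tokens, blocks, children, parent, last_access_time, level). Treating level as tree depth is an assumption, and check_invariants is a hypothetical debugging helper, not part of the original code; it asserts the structural rules that the pitfalls in §8.7 and §8.11 break.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False: nodes compare and hash by identity
class Node:
    tokens: list[int]                 # the run of token ids on this edge
    blocks: list[int]                 # KV block ids holding those tokens
    children: dict[int, Node] = field(default_factory=dict)  # keyed by first token
    parent: Node | None = None
    last_access_time: float = field(default_factory=time.time)  # LRU stamp
    level: int = 0                    # assumed: depth in the tree (root = 0)


def check_invariants(node: Node, block_size: int) -> None:
    """Recursively assert the structural rules the cache depends on."""
    for first, child in node.children.items():
        assert child.tokens and child.tokens[0] == first, "key must be child's first token"
        assert child.parent is node, "child must point back at its parent"
        assert child.level == node.level + 1, "level must be parent's level + 1"
        assert len(child.tokens) == len(child.blocks) * block_size, "node must be block-aligned"
        check_invariants(child, block_size)


# Usage: root -> "A B C D" (two blocks of 2 tokens) -> "E F" (one block)
root = Node(tokens=[], blocks=[])
abcd = Node(tokens=[1, 2, 3, 4], blocks=[7, 8], parent=root, level=1)
root.children[abcd.tokens[0]] = abcd
ef = Node(tokens=[5, 6], blocks=[9], parent=abcd, level=2)
abcd.children[ef.tokens[0]] = ef
check_invariants(root, block_size=2)  # passes silently
```

Identity-based equality matters here: two different nodes can hold identical token runs in different subtrees, and code that stores nodes in lists or sets must not confuse them.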
8.3 The central invariant: block refcounting
Every block has a refcount = (active requests holding it) + (radix-tree nodes referencing it).
```python
class BlockAllocator:
    refcount: list[int]  # NEW

    def alloc_blocks(self, n):
        ...  # (elided in the original; produces `out`, a list of n free block ids)
        for b in out:
            self.refcount[b] = 1
        return out

    def incref(self, blocks):
        for b in blocks:
            self.refcount[b] += 1

    def decref(self, blocks):
        freed = 0
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed
```
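Replace every alloc.free(...) in the codebase with alloc.decref(...). This is the most invasive part of L8: when the cache also holds a ref, decref leaves the block allocated instead of returning it to the free list, and the block is freed only when its last holder lets go.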
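To see the invariant end to end, here is a self-contained version of the allocator above. A simple free-list pop stands in for the elided alloc_blocks body; the constructor, the MemoryError on exhaustion, and the double-free assert are assumptions of this sketch, not the original code. The usage at the bottom walks one pair of blocks through their life: the request allocates them, the radix tree increfs them on insert, the request finishes, and only the cache's eviction actually frees them.

```python
from __future__ import annotations


class BlockAllocator:
    """Refcounted KV-block allocator (sketch)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.refcount = [0] * num_blocks

    def alloc_blocks(self, n: int) -> list[int]:
        if n > len(self.free_blocks):
            raise MemoryError(f"need {n} blocks, {len(self.free_blocks)} free")
        out = [self.free_blocks.pop() for _ in range(n)]
        for b in out:
            self.refcount[b] = 1  # the requester holds the first ref
        return out

    def incref(self, blocks: list[int]) -> None:
        for b in blocks:
            self.refcount[b] += 1

    def decref(self, blocks: list[int]) -> int:
        """Drop one ref per block; return how many blocks hit 0 and were freed."""
        freed = 0
        for b in blocks:
            assert self.refcount[b] > 0, f"double free of block {b}"
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed


alloc = BlockAllocator(num_blocks=4, block_size=16)
req_blocks = alloc.alloc_blocks(2)          # request: refcount 1
alloc.incref(req_blocks)                    # radix tree insert: refcount 2
freed_on_finish = alloc.decref(req_blocks)  # request done: refcount 1, still cached
freed_on_evict = alloc.decref(req_blocks)   # cache evicts: refcount 0, back on free list
print(freed_on_finish, freed_on_evict, len(alloc.free_blocks))
```

The returned freed count is what lets eviction (§8.9) know how much memory it actually reclaimed, since a block shared with a running request survives its decref.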
8.4 Why a radix tree (not a hashmap)
- Sub-prefix matching. If you cached "A B C D" and a new request asks "A B C X", you reuse "A B C". A hashmap matches whole keys; a radix tree matches the longest prefix.
- Incremental insertion + structural sharing. A new request with prompt "A B C E" extends an existing "A B C" node without re-storing the prefix.
8.5 The match algorithm (1 _common_prefix_len call per descent)
In a radix tree no two siblings share a first token (otherwise they'd be merged into a common parent). So children are keyed by first-token in a dict; lookup is O(1):
```python
def match(self, token_ids):  # method of the radix cache; needs `import time` at module level
    node = self.root
    out_blocks, matched = [], 0
    while matched < len(token_ids):
        first = token_ids[matched]
        child = node.children.get(first)  # O(1): radix property
        if child is None:
            break
        cp = self._common_prefix_len(child.tokens, token_ids[matched:])
        child.last_access_time = time.time()
        if cp == len(child.tokens):
            out_blocks.extend(child.blocks)
            matched += cp
            node = child  # full child match: descend
        else:
            blk = cp // self.alloc.block_size
            out_blocks.extend(child.blocks[:blk])
            matched += cp
            break  # partial match: done
    if out_blocks:
        self.alloc.incref(out_blocks)  # caller takes ownership
    return out_blocks, matched
```
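Per descent: exactly one _common_prefix_len call. The original implementation made three kinds of calls (one against the current node, one per child to pick the best match, and a redundant one on recursing into the chosen child), so a descent past N siblings cost N + 2 comparisons. For deep trees with wide sibling sets, the dict-keyed form does one comparison per level instead.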
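The one-call-per-descent claim is easy to check by instrumenting the comparison. The sketch below mirrors match above; TinyAlloc, add_child, and the cpl_calls counter are hypothetical scaffolding for the demo, and a block size of 2 is assumed. The node [1, 2] has three children, yet each descent still costs a single _common_prefix_len call.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    tokens: list[int]
    blocks: list[int]
    children: dict[int, Node] = field(default_factory=dict)
    parent: Node | None = None
    last_access_time: float = 0.0


class TinyAlloc:
    """Hypothetical stand-in that only tracks refcounts."""

    def __init__(self, block_size: int):
        self.block_size = block_size
        self.refcount: dict[int, int] = {}

    def incref(self, blocks: list[int]) -> None:
        for b in blocks:
            self.refcount[b] = self.refcount.get(b, 0) + 1


class RadixCache:
    def __init__(self, alloc: TinyAlloc):
        self.alloc = alloc
        self.root = Node(tokens=[], blocks=[])
        self.cpl_calls = 0  # instrumentation only

    def _common_prefix_len(self, a, b):
        self.cpl_calls += 1
        i = 0
        while i < min(len(a), len(b)) and a[i] == b[i]:
            i += 1
        bs = self.alloc.block_size
        return bs * (i // bs)  # round DOWN to a block boundary

    def add_child(self, parent, tokens, blocks):  # hypothetical test helper
        child = Node(tokens=tokens, blocks=blocks, parent=parent)
        parent.children[tokens[0]] = child
        return child

    def match(self, token_ids):
        node = self.root
        out_blocks, matched = [], 0
        while matched < len(token_ids):
            child = node.children.get(token_ids[matched])  # O(1) dict lookup
            if child is None:
                break
            cp = self._common_prefix_len(child.tokens, token_ids[matched:])
            child.last_access_time = time.time()
            if cp == len(child.tokens):
                out_blocks.extend(child.blocks)
                matched += cp
                node = child  # full match: descend
            else:
                out_blocks.extend(child.blocks[: cp // self.alloc.block_size])
                matched += cp
                break  # partial match: stop
        if out_blocks:
            self.alloc.incref(out_blocks)  # caller now owns one ref per block
        return out_blocks, matched


cache = RadixCache(TinyAlloc(block_size=2))
ab = cache.add_child(cache.root, [1, 2], [10])
cd = cache.add_child(ab, [3, 4], [11])
cache.add_child(ab, [7, 8], [13])  # wide siblings of cd
cache.add_child(ab, [9, 9], [14])
cache.add_child(cd, [5, 6], [12])

result = cache.match([1, 2, 3, 4, 5, 6, 0])
print(result, cache.cpl_calls)
```

Three full descents cost three comparisons regardless of how many siblings each level has; a query that diverges inside [3, 4] costs two (one full descent, one partial).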
8.6 Block alignment and the rounding direction
The cache only stores block-aligned prefixes. _common_prefix_len rounds down:
```python
# end of _common_prefix_len(a, b); i = number of leading tokens where a and b agree
bs = self.alloc.block_size
return bs * (i // bs)  # ROUND DOWN to block boundary
```
Silent-corruption pitfall: writing bs * ((i + 1) // bs) by mistake rounds UP and claims that tokens match when they actually diverged. Demonstrated in the L8 review (block_size = 2): match([1,2,3,99]) against cached [1,2,3,5] returned ([5,7], 4) instead of ([5], 2), so the caller silently used block 7, which holds KV for token 5 where the request has 99.
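The pitfall reproduces in isolation. This sketch models a single cached node through a hypothetical match_one helper (the real match walks the tree); block_size = 2 and block ids [5, 7] come from the review example above. The closing loop is a randomized check that the round-down version never claims a token that disagrees.

```python
import random


def common_prefix_len(a, b, bs, round_up_bug=False):
    i = 0
    while i < min(len(a), len(b)) and a[i] == b[i]:
        i += 1  # i = raw number of agreeing tokens
    if round_up_bug:
        return bs * ((i + 1) // bs)  # WRONG: can step past the divergence point
    return bs * (i // bs)            # correct: round down


def match_one(cached_tokens, cached_blocks, query, bs, round_up_bug=False):
    """Hypothetical single-node match: return (blocks, matched_len)."""
    cp = common_prefix_len(cached_tokens, query, bs, round_up_bug)
    return cached_blocks[: cp // bs], cp


good = match_one([1, 2, 3, 5], [5, 7], [1, 2, 3, 99], bs=2)
bad = match_one([1, 2, 3, 5], [5, 7], [1, 2, 3, 99], bs=2, round_up_bug=True)
print(good, bad)

# Property check: every token the round-down version claims really agrees.
rng = random.Random(0)
for _ in range(1000):
    a = [rng.randrange(3) for _ in range(8)]
    b = [rng.randrange(3) for _ in range(8)]
    cp = common_prefix_len(a, b, bs=4)
    assert cp % 4 == 0 and a[:cp] == b[:cp]
```

Rounding down can only under-claim, which costs a little recomputation; rounding up can over-claim, which costs correctness.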
8.7 _split — the partition operation
When a new prompt partially matches a cached node, split it. The node becomes the matched head; a new split_node holds the tail (and inherits the original's children, with parent pointers updated):
```python
def _split(self, node, k):  # k must be block-aligned
    blk_idx = k // self.alloc.block_size
    head_tokens, tail_tokens = node.tokens[:k], node.tokens[k:]
    head_blocks, tail_blocks = node.blocks[:blk_idx], node.blocks[blk_idx:]
    split_node = Node(
        tokens=tail_tokens, blocks=tail_blocks,
        children=node.children,  # tail takes the subtree
        parent=node,
        # ... (remaining fields elided)
    )
    for grandchild in split_node.children.values():
        grandchild.parent = split_node  # critical: re-parent
    node.tokens = head_tokens
    node.blocks = head_blocks
    node.children = {split_node.tokens[0]: split_node}  # new dict, not the old one
    return node  # the head
```
Aliasing pitfall: sharing the children dict between head and tail, setting node.children = [split_node] (a list instead of a dict), or forgetting to update the grandchildren's parent pointers. Each is a silent topology corruption, and the L8 review caught each of them in turn.
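A standalone sketch of the split, with a walker that would catch all three failure modes. It assumes a fixed BLOCK_SIZE = 2 and omits the elided Node fields (timestamps, and the level bookkeeping §8.9 relies on); check is a hypothetical helper that verifies dict keys, parent pointers, and block alignment across the whole tree.

```python
from __future__ import annotations

from dataclasses import dataclass, field

BLOCK_SIZE = 2  # assumption for the demo


@dataclass(eq=False)
class Node:
    tokens: list[int]
    blocks: list[int]
    children: dict[int, Node] = field(default_factory=dict)
    parent: Node | None = None


def split(node: Node, k: int) -> Node:
    """Cut node after k tokens; node keeps the head, a new child takes the tail."""
    assert k % BLOCK_SIZE == 0 and 0 < k < len(node.tokens), "k must be aligned and interior"
    blk = k // BLOCK_SIZE
    tail = Node(tokens=node.tokens[k:], blocks=node.blocks[blk:],
                children=node.children, parent=node)  # tail takes the subtree
    for grandchild in tail.children.values():
        grandchild.parent = tail                       # re-parent
    node.tokens, node.blocks = node.tokens[:k], node.blocks[:blk]
    node.children = {tail.tokens[0]: tail}            # fresh dict, not the old one
    return node


def check(node: Node) -> None:
    for first, child in node.children.items():
        assert child.tokens[0] == first, "dict key must be the child's first token"
        assert child.parent is node, "stale parent pointer"
        assert len(child.tokens) == BLOCK_SIZE * len(child.blocks), "misaligned node"
        check(child)


# Cached: root -> [1, 2, 3, 4] (blocks 5, 7) -> [9, 9] (block 8)
root = Node(tokens=[], blocks=[])
n = Node(tokens=[1, 2, 3, 4], blocks=[5, 7], parent=root)
root.children[1] = n
g = Node(tokens=[9, 9], blocks=[8], parent=n)
n.children[9] = g

# A new prompt [1, 2, 6, 6] diverges after 2 tokens: split there, add a sibling.
split(n, 2)
n.children[6] = Node(tokens=[6, 6], blocks=[11], parent=n)
check(root)
tail = n.children[3]
print(n.tokens, n.blocks, tail.tokens, tail.blocks, g.parent is tail)
```

Had split mutated the shared dict instead of assigning a fresh one, head and tail would see the same children (including the new [6, 6] sibling) and check would fail on a parent pointer.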
8.8 Scheduler integration
Three changes:
```python
# 1. add_request: lookup + populate slot_indices from cached blocks
def add_request(self, req):
    if self.cache:
        cached_blocks, matched_len = self.cache.match(req.prompt_ids.tolist())
        req.blocks = list(cached_blocks)
        req.cur_len = matched_len
        bs = self.alloc.block_size
        for b in cached_blocks:  # critical
            req.slot_indices.extend(range(b * bs, b * bs + bs))
    self.waiting.append(req)

# 2. on finish: insert full token sequence into cache, then decref request's ref
if next_id == self.eos_id or len(req.output_ids) >= req.max_tokens:
    finished.append(req.id)
    if self.cache:
        self.cache.insert(req.prompt_ids.tolist() + req.output_ids, req.blocks)
    self.alloc.decref(req.blocks)  # not alloc.free

# 3. on OOM: try eviction before raising
def alloc_with_evict(n):
    needed = n - len(alloc.free_blocks)
    if needed > 0:
        cache.evict(needed)
    return alloc.alloc_blocks(n)  # can still raise if eviction freed too little
```
Subtle bug: forgetting to populate req.slot_indices from cached_blocks leaves the list empty while cur_len > 0. When build_meta later reads r.slot_indices[p] for p = cur_len, cur_len+1, ..., the index runs past the end of the list → IndexError. This is what broke our smoke test at B.2; the fix is the inner for b in cached_blocks: req.slot_indices.extend(...) loop.
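Both the fix and the failure can be reproduced with a toy request. Req, seed_from_cache, append_new_block, and slots_to_prefill are hypothetical stand-ins for the scheduler's request object, add_request, the tail allocation, and build_meta's lookup, with block_size = 4 assumed. The slot for offset j inside block b is b * bs + j, which is exactly what the extend loop materializes.

```python
from __future__ import annotations

from dataclasses import dataclass, field

BS = 4  # assumed block size


@dataclass
class Req:
    prompt_len: int
    blocks: list[int] = field(default_factory=list)
    slot_indices: list[int] = field(default_factory=list)
    cur_len: int = 0


def seed_from_cache(req: Req, cached_blocks: list[int], matched_len: int) -> None:
    req.blocks = list(cached_blocks)
    req.cur_len = matched_len
    for b in cached_blocks:  # the fix: one slot per cached position
        req.slot_indices.extend(range(b * BS, b * BS + BS))


def append_new_block(req: Req, b: int) -> None:
    """Hypothetical: a fresh block for the uncached tail."""
    req.blocks.append(b)
    req.slot_indices.extend(range(b * BS, b * BS + BS))


def slots_to_prefill(req: Req) -> list[int]:
    """Like build_meta: look up the slot of every uncached position."""
    return [req.slot_indices[p] for p in range(req.cur_len, req.prompt_len)]


# Correct path: blocks 7 and 2 are cached (8 tokens); block 5 takes the tail.
req = Req(prompt_len=10)
seed_from_cache(req, cached_blocks=[7, 2], matched_len=8)
append_new_block(req, 5)
print(slots_to_prefill(req))

# Buggy path: cur_len is set, but the cached slots were never added.
buggy = Req(prompt_len=10, blocks=[7, 2], cur_len=8)
append_new_block(buggy, 5)
try:
    slots_to_prefill(buggy)
    buggy_failed = False
except IndexError:
    buggy_failed = True
print("buggy path raised IndexError:", buggy_failed)
```

Positions index slot_indices directly, so the list must hold one entry per position from 0, cached or not; skipping the cached prefix shifts every later lookup off the end.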
8.9 Eviction: descending-level, LRU-within-level
```python
def evict(self, n_blocks):
    levels = sorted(self.node_by_level.keys())
    while n_blocks > 0 and levels:
        level = levels.pop()  # deepest first
        nodes = sorted(self.node_by_level[level], key=lambda n: n.last_access_time)
        if not nodes:
            continue
        for i, n in enumerate(nodes):
            freed = self.alloc.decref(n.blocks)
            first = n.tokens[0]
            if first in n.parent.children:
                del n.parent.children[first]  # disconnect from tree
            n_blocks -= freed
            if n_blocks <= 0:
                break
        self.node_by_level[level] = nodes[i + 1:]
```
Two design choices:
- Descending level. The deepest level always holds leaves (a child would sit one level deeper). Once every node at level L is evicted, every node at level L-1 is a leaf too; its level field doesn't change, but the next outer iteration visits level L-1 and processes it. No re-indexing is needed.
- Not strict LRU. Shallow nodes (system prompts, shared bases) are protected; deep leaves (per-request tails) are evicted first, even when a shallow node is older. Real sglang is similar in spirit: its radix cache only ever evicts leaves (LRU among the current leaves, re-queuing a parent once it becomes childless), so shorter, more-shareable prefixes outlive the tails hanging off them.
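Both design choices show up on a four-node tree. This sketch copies evict from above; Node, Alloc, Cache, and the add helper are hypothetical scaffolding, level is taken to be tree depth, and each node owns one block held only by the tree. Node d is the oldest entry in the cache, yet it survives the first call because it sits at a shallower level.

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    tokens: list[int]
    blocks: list[int]
    level: int
    last_access_time: float
    children: dict[int, Node] = field(default_factory=dict)
    parent: Node | None = None


class Alloc:
    def __init__(self, num_blocks: int):
        self.refcount = [0] * num_blocks
        self.free_blocks: list[int] = []

    def decref(self, blocks: list[int]) -> int:
        freed = 0
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed


class Cache:
    def __init__(self, alloc: Alloc):
        self.alloc = alloc
        self.root = Node(tokens=[], blocks=[], level=0, last_access_time=0.0)
        self.node_by_level: dict[int, list[Node]] = defaultdict(list)

    def add(self, parent: Node, tokens: list[int], blocks: list[int], t: float) -> Node:
        """Hypothetical helper: attach, index by level, take the tree's ref."""
        node = Node(tokens, blocks, parent.level + 1, t, parent=parent)
        parent.children[tokens[0]] = node
        self.node_by_level[node.level].append(node)
        for b in blocks:
            self.alloc.refcount[b] += 1
        return node

    def evict(self, n_blocks: int) -> None:
        levels = sorted(self.node_by_level.keys())
        while n_blocks > 0 and levels:
            level = levels.pop()  # deepest first
            nodes = sorted(self.node_by_level[level], key=lambda n: n.last_access_time)
            if not nodes:
                continue
            for i, n in enumerate(nodes):
                freed = self.alloc.decref(n.blocks)
                first = n.tokens[0]
                if first in n.parent.children:
                    del n.parent.children[first]  # disconnect from tree
                n_blocks -= freed
                if n_blocks <= 0:
                    break
            self.node_by_level[level] = nodes[i + 1:]


cache = Cache(Alloc(num_blocks=4))
a = cache.add(cache.root, [1, 1], [0], t=1.0)  # shared prefix
b = cache.add(a, [2, 2], [1], t=2.0)           # per-request tails
c = cache.add(a, [3, 3], [2], t=3.0)
d = cache.add(cache.root, [4, 4], [3], t=0.5)  # oldest, but shallow

cache.evict(2)   # takes b and c (level 2); a is now a leaf
after_first = list(cache.alloc.free_blocks)
cache.evict(1)   # level 2 is empty; LRU at level 1 is d
print(after_first, cache.alloc.free_blocks, list(cache.root.children))
```

After the first call, a is still indexed at level 1 with no children, which is the implicit promotion discussed in the Q&A below; the second call reaches it without any re-indexing.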
8.10 Q&A from session
student question
Eviction worries me — when I wipe all leaves, the parents become leaves but their level field doesn't change. Won't they be orphaned in node_by_level?
No — and you were right, I was wrong initially. The level field stays at L-1; the node is still indexed in node_by_level[L-1]. The outer loop pops levels descending, so after exhausting level L it visits L-1 and finds those nodes (now leaves). Promotion is implicit.
student question
Why do we call _common_prefix_len three times to visit a child?
The original recursive form did: 1 against current node + N against each child to pick best + 1 redundant inside the recursion entering the chosen child. The radix-tree-correct form is one dict lookup + one _common_prefix_len per descent. See §8.5.
8.11 Pitfalls table (real ones we hit)
| pitfall | symptom |
|---|---|
| round UP in _common_prefix_len | match claims more tokens than actually agree → caller uses a block holding the wrong tokens → silent KV corruption |
| set node.child = [...] instead of node.children = {...} | typo creates a new instance attribute; the original .children is untouched; subtree silently orphaned |
| list comprehension over a dict in evict | iterates keys, not values; the filter is a no-op; node never removed → next match hits a recycled block |
| insert slices blocks by token count | blocks[cp_len:] uses a token count as a block index; should be blocks[cp_len // bs:] |
| missing incref on match/insert | cache holds a ref the allocator doesn't know about; evict frees a block another request is reading → garbage |
| insert silently drops "redundant" caller blocks | the caller's ref on each dropped block leaks until the request finishes; document the contract or have insert return them |
| scheduler forgets to populate slot_indices from cached blocks | build_meta raises IndexError at the first lookup past the cached prefix (p = cur_len) |
| _split doesn't re-parent transferred grandchildren | evict walks n.parent up the wrong chain → deletes the wrong dict entry |
8.12 What we skipped vs production
| feature | real sglang | us |
|---|---|---|
| Per-token (un-aligned) caching | yes | block-aligned only |
| Hierarchical (CPU offload / HiCache) | yes | no |
| Cache-aware scheduling (prioritize high-hit requests) | yes | FCFS |
| Hot-set protection in eviction | yes | basic descending-level LRU |
| Cross-tenant isolation | configurable | single tenant |
MVP is ~250 LOC of new code. Captures roughly 80% of the throughput win for chat workloads.
8.13 Acceptance (verified)
L8 PASS
- Phase A (units): empty match, exact match, partial-at-odd-token (regression), insert-with-split + 4 post-split matches, no false hits, eviction actually frees.
- Phase B (e2e): cached run uses 25 q tokens vs 121 baseline (79.3% saved) for the same prompt. Output bit-identical.
- Phase C: unrelated prompt → matched_len=0.