mini_sglang

L8 — Radix prefix cache

The defining sglang trick. When two requests share a prefix, reuse the prefix's KV instead of re-computing it.

8.1 The win, quantified

Chat workloads put the same system prompt, say 500 tokens, at the start of every request. Without prefix caching, every request re-prefills those 500 tokens. With prefix caching they are prefilled once; subsequent requests bump the refcount on the cached blocks and start from the cached KV.

Verified on Qwen3-8B with a 97-token shared prefix + 5 unmatched tail tokens (102 prompt tokens; the match is block-aligned, see §8.6, so 96 of the 97 shared tokens come from the cache):

WITH cache:   ext run = 25 q tokens in 20 steps   ← only 6 uncached tokens prefilled
                                                    + 19 decodes × 1 token = 25 total
BASELINE:     ext run = 121 q tokens in 20 steps  ← full 102-token re-prefill + 19 decodes

✓ cache saved 96 q tokens (79.3% of the baseline's 121 q tokens; 94.1% of its 102-token prefill)
✓ cached output bit-identical to baseline
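
The counts in the run above follow from simple bookkeeping. A minimal sketch of the arithmetic, assuming 96 of the 102 prompt tokens are served from the cache and that the 20 steps are 1 prefill step plus 19 decodes (the helper name `q_tokens` is illustrative, not part of the codebase):

```python
def q_tokens(prompt_len: int, cached_len: int, decode_steps: int) -> int:
    """Query tokens processed: uncached prefill tokens plus one per decode step."""
    return (prompt_len - cached_len) + decode_steps

prompt_len = 97 + 5      # shared prefix + tail, from the run above
decode_steps = 19        # assumed: 20 steps = 1 prefill step + 19 decodes

with_cache = q_tokens(prompt_len, cached_len=96, decode_steps=decode_steps)
baseline = q_tokens(prompt_len, cached_len=0, decode_steps=decode_steps)
saved = baseline - with_cache

print(with_cache, baseline, saved)   # 25 121 96
print(f"{saved / baseline:.1%}")     # 79.3%
```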

8.2 The data structure: radix tree of token sequences

Compressed trie keyed by token IDs. Each node = (a run of tokens, the KV blocks for those tokens, dict of children keyed by first-token, parent, LRU timestamp).

                     root (tokens=[], blocks=[])
                                 │
                   ┌─────────────┴─────────────┐
                   │                           │
          [12, 45, 17, ...]             [99, 3, 8, ...]
        500 tokens, 32 blocks        200 tokens, 13 blocks
           "system prompt"         "different system prompt"
                   │
         ┌─────────┴─────────┐
         │                   │
   [42, 7, ...]        [88, 1, ...]
  user A's turn 1     user B's turn 1
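
The node layout described above can be sketched as a small dataclass. This is an illustration under assumptions, not the project's actual definition: the field names follow the description in this section, `add_child` is a hypothetical helper, and the token and block values are made up.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity semantics: nodes point at each other, so no field-wise __eq__
class Node:
    tokens: list[int] = field(default_factory=list)          # run of token IDs on this edge
    blocks: list[int] = field(default_factory=list)          # KV block IDs for those tokens
    children: dict[int, Node] = field(default_factory=dict)  # keyed by the child's first token
    parent: Node | None = field(default=None, repr=False)    # repr=False avoids infinite recursion
    last_access_time: float = field(default_factory=time.time)

    def add_child(self, child: Node) -> Node:
        """Hypothetical helper: attach a child under its first token."""
        child.parent = self
        self.children[child.tokens[0]] = child
        return child


root = Node()
sys_a = root.add_child(Node(tokens=[12, 45, 17], blocks=[0]))
sys_b = root.add_child(Node(tokens=[99, 3, 8], blocks=[1]))
turn_a = sys_a.add_child(Node(tokens=[42, 7], blocks=[2]))
turn_b = sys_a.add_child(Node(tokens=[88, 1], blocks=[3]))

print(sorted(root.children))    # [12, 99]
print(sorted(sys_a.children))   # [42, 88]
```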

Three operations: match, insert, evict.

8.3 The central invariant: block refcounting

Every block has a refcount = (active requests holding it) + (radix-tree nodes referencing it).

class BlockAllocator:
    refcount: list[int]                  # NEW

    def alloc_blocks(self, n):
        ...
        for b in out: self.refcount[b] = 1
        return out

    def incref(self, blocks):
        for b in blocks: self.refcount[b] += 1

    def decref(self, blocks):
        freed = 0
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed

Replace every alloc.free(...) in the codebase with alloc.decref(...). This is the most invasive part of L8, and it is what makes sharing safe: when the cache also holds a ref, a finishing request's decref leaves the block alive instead of returning it to the free list.
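
To see the invariant in action, here is a runnable sketch of the allocator. The `incref`/`decref` bodies follow the listing above; the constructor and the body of `alloc_blocks` are assumptions, since the original elides them.

```python
class BlockAllocator:
    def __init__(self, num_blocks: int):
        # Assumed layout: a free list plus one refcount slot per block.
        self.free_blocks = list(range(num_blocks))
        self.refcount = [0] * num_blocks

    def alloc_blocks(self, n: int) -> list[int]:
        # Assumed body: pop n blocks off the free list, each starting at refcount 1.
        if n > len(self.free_blocks):
            raise RuntimeError("out of KV blocks")
        out = [self.free_blocks.pop() for _ in range(n)]
        for b in out:
            self.refcount[b] = 1
        return out

    def incref(self, blocks: list[int]) -> None:
        for b in blocks:
            self.refcount[b] += 1

    def decref(self, blocks: list[int]) -> int:
        freed = 0
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed


alloc = BlockAllocator(4)
blocks = alloc.alloc_blocks(2)        # the request holds one ref on each block
alloc.incref(blocks)                  # the cache takes a second ref
freed_first = alloc.decref(blocks)    # request finishes: blocks stay alive
freed_second = alloc.decref(blocks)   # cache evicts: blocks return to the free list
print(freed_first, freed_second)      # 0 2
```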

8.4 Why a radix tree (not a hashmap)

  1. Sub-prefix matching. If you cached "A B C D" and a new request asks "A B C X", you reuse "A B C". A hashmap matches whole keys; a radix tree matches the longest prefix.
  2. Incremental insertion + structural sharing. A new request with prompt "A B C E" extends an existing "A B C" node without re-storing the prefix.
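
The first point can be illustrated in a few lines. This sketch only contrasts whole-key lookup with prefix comparison; it does not build a tree, and `common_prefix_len` here is a plain helper without block alignment.

```python
def common_prefix_len(a, b) -> int:
    """Number of leading positions where a and b agree."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


cached = ("A", "B", "C", "D")
query = ("A", "B", "C", "X")

exact_cache = {cached: "kv-for-ABCD"}       # hashmap keyed by the whole sequence
print(exact_cache.get(query))               # None: whole-key lookup misses
print(common_prefix_len(cached, query))     # 3: a prefix match can reuse "A B C"
```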

8.5 The match algorithm (1 _common_prefix_len call per descent)

In a radix tree no two siblings share a first token (otherwise they'd be merged into a common parent). So children are keyed by first-token in a dict; lookup is O(1):

def match(self, token_ids):
    node = self.root
    out_blocks, matched = [], 0
    while matched < len(token_ids):
        first = token_ids[matched]
        child = node.children.get(first)            # O(1) — radix property
        if child is None: break

        cp = self._common_prefix_len(child.tokens, token_ids[matched:])
        child.last_access_time = time.time()

        if cp == len(child.tokens):
            out_blocks.extend(child.blocks)
            matched += cp
            node = child                            # full child match — descend
        else:
            blk = cp // self.alloc.block_size
            out_blocks.extend(child.blocks[:blk])
            matched += cp
            break                                   # partial match — done
    if out_blocks:
        self.alloc.incref(out_blocks)               # caller takes ownership
    return out_blocks, matched

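Per descent: exactly one _common_prefix_len call. The original implementation had three call sites (self check + per-child loop + redundant call inside the recursion), i.e. N + 2 calls to enter one of N children. With the dict lookup, the cost of a descent no longer depends on the number of siblings, which matters for deep trees with wide fan-out.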
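
The one-call-per-descent claim can be checked on a toy tree. The sketch below is a simplified stand-in for the cache class: it has no allocator (so no incref), the tree is built by hand, and `ToyRadixCache` and the `prefix_calls` counter are illustrative names that do not exist in the codebase.

```python
import time


class Node:
    def __init__(self, tokens, blocks, parent=None):
        self.tokens, self.blocks = tokens, blocks
        self.children, self.parent = {}, parent
        self.last_access_time = time.time()
        if parent is not None:
            parent.children[tokens[0]] = self   # keyed by first token


class ToyRadixCache:
    def __init__(self, block_size: int):
        self.block_size = block_size
        self.root = Node([], [])
        self.prefix_calls = 0                   # instrumentation for this demo only

    def _common_prefix_len(self, a, b) -> int:
        self.prefix_calls += 1
        i, limit = 0, min(len(a), len(b))
        while i < limit and a[i] == b[i]:
            i += 1
        return self.block_size * (i // self.block_size)   # round down to a block boundary

    def match(self, token_ids):
        node, out_blocks, matched = self.root, [], 0
        while matched < len(token_ids):
            child = node.children.get(token_ids[matched])  # O(1) dict lookup
            if child is None:
                break
            cp = self._common_prefix_len(child.tokens, token_ids[matched:])
            child.last_access_time = time.time()
            out_blocks.extend(child.blocks[: cp // self.block_size])
            matched += cp
            if cp < len(child.tokens):
                break                                      # partial match: stop here
            node = child                                   # full match: descend
        return out_blocks, matched


cache = ToyRadixCache(block_size=2)
system = Node([1, 2, 3, 4], [10, 11], parent=cache.root)
Node([5, 6], [12], parent=system)
Node([7, 8, 9, 10], [13, 14], parent=system)

result = cache.match([1, 2, 3, 4, 7, 8, 99, 0])
print(result)               # ([10, 11, 13], 6)
print(cache.prefix_calls)   # 2: one call per node entered
```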

8.6 Block alignment and the rounding direction

The cache only stores block-aligned prefixes. _common_prefix_len rounds down:

bs = self.alloc.block_size
return bs * (i // bs)             # ROUND DOWN to block boundary

silent-corruption pitfall: Writing bs * ((i + 1) // bs) by mistake rounds UP and claims tokens match that actually diverged. Demonstrated in the L8 review: match([1,2,3,99]) against cached [1,2,3,5] returned ([5,7], 4) instead of ([5], 2), so the caller silently uses block 7, which holds the wrong token.
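
A standalone sketch of the rounding, reproducing the review example. The function bodies are a reconstruction (only the final return line appears above), they are written as free functions rather than methods, and `block_size=2` is inferred from the example, where four tokens map to two blocks.

```python
def common_prefix_len(a, b, block_size: int) -> int:
    """Length of the shared prefix, rounded DOWN to a block boundary."""
    i, limit = 0, min(len(a), len(b))
    while i < limit and a[i] == b[i]:
        i += 1
    return block_size * (i // block_size)


def common_prefix_len_buggy(a, b, block_size: int) -> int:
    """Same scan, but with the (i + 1) mistake from the pitfall."""
    i, limit = 0, min(len(a), len(b))
    while i < limit and a[i] == b[i]:
        i += 1
    return block_size * ((i + 1) // block_size)


cached, query = [1, 2, 3, 5], [1, 2, 3, 99]       # agree on 3 tokens
good = common_prefix_len(cached, query, block_size=2)
bad = common_prefix_len_buggy(cached, query, block_size=2)
print(good)   # 2: only the first block is reusable
print(bad)    # 4: claims the second block too, which holds token 5, not 99
```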

8.7 _split — the partition operation

When a new prompt matches only the first k tokens of a cached node, split the node at k. The node becomes the matched head; a new split_node holds the tail (and inherits the original's children, with parent pointers updated):

def _split(self, node, k):                                # k must be block-aligned
    blk_idx = k // self.alloc.block_size
    head_tokens,  tail_tokens  = node.tokens[:k],  node.tokens[k:]
    head_blocks,  tail_blocks  = node.blocks[:blk_idx], node.blocks[blk_idx:]

    split_node = Node(
        tokens=tail_tokens, blocks=tail_blocks,
        children=node.children,                           # tail takes the subtree
        parent=node, ...
    )
    for grandchild in split_node.children.values():
        grandchild.parent = split_node                    # ← critical: re-parent

    node.tokens   = head_tokens
    node.blocks   = head_blocks
    node.children = {split_node.tokens[0]: split_node}    # new dict, not the old one
    return node                                           # the head

aliasing pitfall: Sharing children dict references between head and tail, or setting node.children = [split_node] (a list instead of a dict), or forgetting to update grandchildren's parent pointers: each is a silent topology corruption that the L8 review caught in turn.
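
The three invariants named in the pitfall can be asserted directly on a toy tree. This is a sketch under assumptions: `Node` is a stripped-down stand-in, and `split` is a free function taking `block_size` as a parameter rather than the `_split` method itself.

```python
class Node:
    def __init__(self, tokens, blocks, parent=None):
        self.tokens, self.blocks = tokens, blocks
        self.children, self.parent = {}, parent
        if parent is not None:
            parent.children[tokens[0]] = self


def split(node, k: int, block_size: int):
    """Split node at token offset k (block-aligned); returns the head."""
    assert k % block_size == 0 and 0 < k < len(node.tokens)
    blk_idx = k // block_size
    # Detach the old subtree and give the head a fresh dict (no aliasing).
    old_children, node.children = node.children, {}
    tail = Node(node.tokens[k:], node.blocks[blk_idx:], parent=node)
    tail.children = old_children            # tail takes the subtree
    for grandchild in tail.children.values():
        grandchild.parent = tail            # re-parent
    node.tokens, node.blocks = node.tokens[:k], node.blocks[:blk_idx]
    return node


root = Node([], [])
n = Node([1, 2, 3, 4], [10, 11], parent=root)
leaf = Node([5, 6], [12], parent=n)

head = split(n, 2, block_size=2)
tail = head.children[3]                     # keyed by the tail's first token

print(head.tokens, head.blocks)             # [1, 2] [10]
print(tail.tokens, tail.blocks)             # [3, 4] [11]
print(leaf.parent is tail)                  # True
print(head.children is tail.children)       # False
```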

8.8 Scheduler integration

Three changes:

# 1. add_request: lookup + populate slot_indices from cached blocks
def add_request(self, req):
    if self.cache:
        cached_blocks, matched_len = self.cache.match(req.prompt_ids.tolist())
        req.blocks   = list(cached_blocks)
        req.cur_len  = matched_len
        bs = self.alloc.block_size
        for b in cached_blocks:                               # ← critical
            req.slot_indices.extend(range(b * bs, b * bs + bs))
    self.waiting.append(req)

# 2. on finish: insert full token sequence into cache, then decref request's ref
if next_id == self.eos_id or len(req.output_ids) >= req.max_tokens:
    finished.append(req.id)
    if self.cache:
        self.cache.insert(req.prompt_ids.tolist() + req.output_ids, req.blocks)
    self.alloc.decref(req.blocks)                              # not alloc.free

# 3. on OOM: try eviction before raising
def alloc_with_evict(n):
    needed = n - len(alloc.free_blocks)
    if needed > 0:
        cache.evict(needed)
    return alloc.alloc_blocks(n)

subtle bug: Forgetting to populate req.slot_indices from cached_blocks leaves the list empty while cur_len > 0. Then build_meta later does r.slot_indices[p] for p = cur_len, cur_len+1, ... and raises IndexError. This bit our smoke test on B.2; the fix is the inner loop, for b in cached_blocks: req.slot_indices.extend(...).
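
The block-to-slot mapping behind that fix is small enough to check in isolation. A sketch with a hypothetical helper name (`slots_for_blocks`) and made-up block IDs:

```python
def slots_for_blocks(blocks, block_size: int) -> list[int]:
    """Expand block IDs into per-token slot indices, in prompt order."""
    slots = []
    for b in blocks:
        slots.extend(range(b * block_size, (b + 1) * block_size))
    return slots


cached_blocks, block_size = [5, 2], 4
slot_indices = slots_for_blocks(cached_blocks, block_size)
cur_len = len(cached_blocks) * block_size   # matched_len is block-aligned

print(slot_indices)   # [20, 21, 22, 23, 8, 9, 10, 11]
# One slot per cached token, so positions cur_len, cur_len+1, ... line up
# with whatever is appended next.
assert len(slot_indices) == cur_len
```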

8.9 Eviction: descending-level, LRU-within-level

def evict(self, n_blocks):
    levels = sorted(self.node_by_level.keys())
    while n_blocks > 0 and levels:
        level = levels.pop()                         # deepest first
        nodes = sorted(self.node_by_level[level], key=lambda n: n.last_access_time)
        if not nodes: continue
        for i, n in enumerate(nodes):
            freed = self.alloc.decref(n.blocks)
            first = n.tokens[0]
            if first in n.parent.children:
                del n.parent.children[first]         # disconnect from tree
            n_blocks -= freed
            if n_blocks <= 0: break
        self.node_by_level[level] = nodes[i+1:]

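Two design choices:

  1. Deepest level first. A node's KV is only reachable through its ancestors, so leaves go before the prefixes they extend; a shared system prompt near the root is the last thing evicted.
  2. LRU within a level. Among nodes at the same depth, the one with the oldest last_access_time (refreshed by match) is evicted first.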

8.10 Q&A from session

student question: Eviction worries me: when I wipe all leaves, the parents become leaves but their level field doesn't change. Won't they be orphaned in node_by_level?

No, they are not orphaned. (My first answer in the session got this wrong; you were right to push on it.) A parent of level-L leaves keeps level = L-1 and stays indexed in node_by_level[L-1]. The outer loop pops levels in descending order, so after exhausting level L it visits L-1 and finds those nodes, which are now leaves. Promotion is implicit.
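
The implicit promotion can be checked on a three-node toy. This sketch follows the evict loop from §8.9 but as a free function; `Alloc` is a minimal stand-in that only supports decref, the timestamps are hand-set, and the assumption that the root is never indexed in node_by_level is mine.

```python
class Alloc:
    def __init__(self, n: int):
        self.refcount = [1] * n      # pretend the cache holds the only ref on each block
        self.free_blocks = []

    def decref(self, blocks) -> int:
        freed = 0
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                self.free_blocks.append(b)
                freed += 1
        return freed


class Node:
    def __init__(self, tokens, blocks, parent, level, t):
        self.tokens, self.blocks, self.parent = tokens, blocks, parent
        self.children, self.level, self.last_access_time = {}, level, t
        if parent is not None:
            parent.children[tokens[0]] = self


def evict(alloc, node_by_level, n_blocks):
    evicted = []                                   # returned for inspection only
    levels = sorted(node_by_level.keys())
    while n_blocks > 0 and levels:
        level = levels.pop()                       # deepest first
        nodes = sorted(node_by_level[level], key=lambda nd: nd.last_access_time)
        if not nodes:
            continue
        for i, node in enumerate(nodes):
            n_blocks -= alloc.decref(node.blocks)
            node.parent.children.pop(node.tokens[0], None)   # disconnect from tree
            evicted.append(node.tokens)
            if n_blocks <= 0:
                break
        node_by_level[level] = nodes[i + 1:]
    return evicted


alloc = Alloc(3)
root = Node([], [], None, 0, 0.0)
parent = Node([1, 2], [0], root, 1, 1.0)
leaf_old = Node([3, 4], [1], parent, 2, 2.0)
leaf_new = Node([5, 6], [2], parent, 2, 3.0)
node_by_level = {1: [parent], 2: [leaf_old, leaf_new]}   # assumed: root is not indexed

order = evict(alloc, node_by_level, 3)
print(order)   # [[3, 4], [5, 6], [1, 2]]: both leaves (LRU first), then the parent
```

The parent's level field never changes; it is reached simply because the loop moves on to level 1 once level 2 is empty.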

student question: Why do we call _common_prefix_len three times to visit a child?

The original recursive form did: 1 against current node + N against each child to pick best + 1 redundant inside the recursion entering the chosen child. The radix-tree-correct form is one dict lookup + one _common_prefix_len per descent. See §8.5.

8.11 Pitfalls table (real ones we hit)

| pitfall | symptom |
|---|---|
| round UP in _common_prefix_len | match claims more tokens than actually agree → caller uses block holding wrong tokens → silent KV corruption |
| set node.child = [...] instead of node.children = {...} | typo creates a new dataclass attribute; original .children untouched; subtree silently orphaned |
| list-comprehension on dict in evict | iterates keys not values; filter is no-op; node never removed → next match hits a recycled block |
| insert slices blocks by token count | blocks[cp_len:] when cp_len is tokens; should be blocks[cp_len // bs:] |
| missing incref on match/insert | cache holds a ref the allocator doesn't know about; evict frees a block another request is reading → garbage |
| insert silently drops "redundant" caller blocks | caller's refcount on the dropped block leaks until request finishes; document the contract or have insert return them |
| scheduler forgets to populate slot_indices from cached blocks | build_meta IndexErrors at the first cached-position lookup |
| _split doesn't re-parent transferred grandchildren | evict walks n.parent up the wrong chain → deletes the wrong dict entry |

8.12 What we skipped vs production

| feature | real sglang | us |
|---|---|---|
| Per-token (un-aligned) caching | yes | block-aligned only |
| Hierarchical (CPU offload / HiCache) | yes | no |
| Cache-aware scheduling (prioritize high-hit requests) | yes | FCFS |
| Hot-set protection in eviction | yes | basic descending-level LRU |
| Cross-tenant isolation | configurable | single tenant |

The MVP is ~250 LOC of new code and captures roughly 80% of the throughput win for chat workloads.

8.13 Acceptance (verified)

L8 PASS