mini_sglang

L3 — Paged attention kernel

Block-table-aware attention. Stop gathering on the host; let the kernel walk the page table.

3.1 Why L2 hits a wall

L2's attention reads its history with a gather issued from the host before attention runs. The gather materialises a contiguous copy:

K_full = pool.K[layer][meta.kv_indices]   # gather, shape [T_kv, H_kv, D]

Two problems:

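  1. HBM traffic and intermediate allocation. For Qwen3-8B at 1k context, that's roughly 1024 tokens × 8 KV heads × 128 dims × 2 B (bf16) × 36 layers × 2 (K and V) ≈ 144 MiB copied per sequence per step, plus the temporary buffers to hold it. It grows linearly with both batch size and context length.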
  2. Can't represent multiple sequences in one kernel call. A flat kv_indices can't say "query q attends only to keys 3..14 of its sequence". You'd need an explicit boolean mask over every (query, key) pair in the batch, which is expensive to build and to apply.
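The arithmetic in point 1 can be checked with a minimal Python sketch. The helper name gather_bytes_per_step is hypothetical. The config values assume bf16 and Qwen3-8B's 36 layers, 8 KV heads and head_dim 128 as used above. It counts the bytes written by the copy for one sequence. A gather also reads its source, so actual HBM traffic is roughly double.

```python
def gather_bytes_per_step(ctx_len: int, num_kv_heads: int, head_dim: int,
                          dtype_bytes: int, num_layers: int) -> int:
    """Bytes copied by the gather for ONE sequence in ONE forward step."""
    per_tensor = ctx_len * num_kv_heads * head_dim * dtype_bytes  # one layer's K (or V)
    return per_tensor * num_layers * 2                             # all layers, K and V

# Qwen3-8B-like config: 1k context, 8 KV heads, head_dim 128, bf16, 36 layers.
qwen3_8b = gather_bytes_per_step(1024, 8, 128, 2, 36)
print(f"{qwen3_8b / 2**20:.0f} MiB per sequence per step")  # prints: 144 MiB per sequence per step
```

Doubling the context (or the number of sequences) doubles the figure, which is the linear scaling the text describes.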

The fix: keep storage paged but pass the kernel structural metadata describing where each sequence lives, then call the kernel once for the whole batch.

3.2 The three metadata tensors

Memorize these. Paged attention kernels take some variant of them (flash_attn_with_kvcache in 3.4, for instance, takes seq_lens_kv and block_table).

cu_seqlens_q : [num_seqs + 1]   prefix-sum of query lengths (starts at 0)
seq_lens_kv  : [num_seqs]       total cached length per sequence
block_table  : [num_seqs, max_blocks_per_seq]   per-seq page table (right-padded)

Examples:

| scenario | cu_seqlens_q | seq_lens_kv | block_table |
|---|---|---|---|
| single-seq prefill of N | [0, N] | [N] | [[b0, b1, …]] |
| single-seq decode at step i (i = 0, 1, …) | [0, 1] | [N+i+1] | [[b0, b1, …]] |
| two-seq mixed (A: decode, B: prefill chunk) | [0, 1, 1+chunk_B] | [len_A, len_B] | [[A_blocks…], [B_blocks…]] |
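The mixed row of the table can be assembled from per-request state in a few lines. This is a sketch under assumptions. InflightReq and build_paged_meta are hypothetical names, not L2's classes. Each request is assumed to already hold enough blocks for kv_len. block_table is right-padded with 0, which is safe only because kernels read just the first ceil(kv_len / block_size) entries of each row.

```python
from dataclasses import dataclass

import torch

@dataclass
class InflightReq:      # hypothetical stand-in for a scheduler's request object
    blocks: list        # physical block ids, in logical order
    kv_len: int         # cached tokens AFTER this step's K/V are written
    q_len: int          # tokens computed this step (1 for decode, chunk size for prefill)

def build_paged_meta(reqs, device="cpu"):
    """Pack per-request state into (cu_seqlens_q, seq_lens_kv, block_table), all int32."""
    q_lens = torch.tensor([r.q_len for r in reqs], dtype=torch.int32)
    cu_seqlens_q = torch.zeros(len(reqs) + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)                 # prefix sum, starts at 0
    seq_lens_kv = torch.tensor([r.kv_len for r in reqs], dtype=torch.int32)
    max_blocks = max(len(r.blocks) for r in reqs)
    block_table = torch.zeros(len(reqs), max_blocks, dtype=torch.int32)  # right-padded with 0
    for i, r in enumerate(reqs):
        block_table[i, : len(r.blocks)] = torch.tensor(r.blocks, dtype=torch.int32)
    return cu_seqlens_q.to(device), seq_lens_kv.to(device), block_table.to(device)

# Mixed step with block_size = 4: A decodes its 9th token, B prefills a first chunk of 5.
reqs = [InflightReq(blocks=[3, 7, 8], kv_len=9, q_len=1),
        InflightReq(blocks=[1, 2], kv_len=5, q_len=5)]
cu, lens, table = build_paged_meta(reqs)
print(cu.tolist(), lens.tolist(), table.tolist())
# [0, 1, 6] [9, 5] [[3, 7, 8], [1, 2, 0]]
```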

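Query token q of sequence s (0-based among that sequence's new tokens, where q_len[s] = cu_seqlens_q[s+1] − cu_seqlens_q[s]) attends to keys 0 .. seq_lens_kv[s] − q_len[s] + q (inclusive) of sequence s only. Inter-sequence attention is structurally impossible, which is exactly what we want for batching.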
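To make the visibility rule concrete, this hypothetical helper lists the keys each query token may see, using the same formula with an inclusive upper bound. The last case, a prefill chunk on top of cached tokens, is where the last visible key is offset from the query index.

```python
def visible_keys(q_len: int, kv_len: int) -> list:
    """Per query token of one sequence: the range of key indices it may attend to.

    Query q sits at absolute position kv_len - q_len + q, so it sees keys
    0 .. kv_len - q_len + q inclusive.
    """
    return [range(0, kv_len - q_len + q + 1) for q in range(q_len)]

print([len(r) for r in visible_keys(4, 4)])  # prefill of N = 4      -> [1, 2, 3, 4]
print([len(r) for r in visible_keys(1, 7)])  # decode, kv_len = 7    -> [7]
print([len(r) for r in visible_keys(3, 8)])  # chunk of 3 after 5    -> [6, 7, 8]
```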

3.3 Storage layout switch — block-tensor

Reshape each layer's K, V from [num_slots, H_kv, D] (token-flat) to [num_blocks, block_size, H_kv, D] (block-tensor). Same bytes, different addressing. The block dim is what kernels expect for paged reads.

The view trick that makes set_kv work

Pool tensor (per layer), shape [num_blocks, block_size, H_kv, D]:
  num_blocks   which block
  block_size   which slot INSIDE the block
  H_kv         how many KV heads
  D            per-head dim

slot_mapping (per token), shape [T], e.g. [5, 6, 7, 21, 22, 23]: FLAT token indices.

PyTorch fancy-indexing rule: pool.K[layer][slot_mapping] selects along axis 0 (num_blocks), so it interprets [5, 6, 7] as BLOCK ids → wrong shape, garbage writes.

Fix with a zero-copy reshape:

def set_kv(self, layer_id, slot_mapping, k, v):
    K = self.K[layer_id].view(-1, self.num_kv_heads, self.head_dim)  # [N*B, H_kv, D]
    V = self.V[layer_id].view(-1, self.num_kv_heads, self.head_dim)
    K[slot_mapping] = k
    V[slot_mapping] = v

Same bytes, two views (N = num_blocks, B = block_size):
  [N*B, H, D]   ← view used for SCATTER (writes, per token), indexed by slot_mapping[T] from BlockAllocator
  [N, B, H, D]  ← original, used for GATHER (reads, per block), indexed by block_table[S, max_b] from req.blocks

The flat coordinate flat = block_id × block_size + offset (the slot within the block) is exactly what BlockAllocator already produces in req.slot_indices. No translation math, just a view to match.
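A small CPU demo of the two views, with toy sizes. A hypothetical allocation (blocks [2, 0]) stands in for BlockAllocator. The demo scatters through the flat view with slot indices, gathers back through the block view as Path A does, and checks that both see the same bytes.

```python
import torch

num_blocks, block_size, H_kv, D = 4, 4, 2, 3
K = torch.zeros(num_blocks, block_size, H_kv, D)            # block-tensor layout

# Hypothetical allocation: the sequence owns blocks [2, 0] and has 6 tokens.
blocks = [2, 0]
slot_indices = [b * block_size + off for b in blocks for off in range(block_size)][:6]
# -> [8, 9, 10, 11, 0, 1]

k = torch.arange(6 * H_kv * D, dtype=torch.float32).view(6, H_kv, D)  # fake per-token keys

# SCATTER through the flat view: one row per token slot, same storage as K.
K_flat = K.view(-1, H_kv, D)                                 # [num_blocks*block_size, H_kv, D]
K_flat[torch.tensor(slot_indices)] = k

# GATHER through the original block view, the way Path A reads a sequence.
block_table = torch.tensor(blocks)
K_seq = K[block_table].reshape(-1, H_kv, D)[:6]
print(torch.equal(K_seq, k))                                 # True: same bytes, two views

# The pitfall: indexing the un-viewed tensor picks whole BLOCKS.
print(K[torch.tensor([0, 1])].shape)                         # torch.Size([2, 4, 2, 3])
```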

3.4 Two implementation paths

Path A — Pure PyTorch reference (no install)

For each sequence, do a per-seq gather → SDPA. Same correctness as a real kernel, just no fusion. Use it to validate that your metadata is shaped right.

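import torch
import torch.nn.functional as F

def attn_paged_torch(q, K_cache, V_cache, meta, *, H_q, H_kv):
    # q: [T_total_q, H_q, D], packed by meta.cu_seqlens_q
    # K_cache, V_cache: [num_blocks, block_size, H_kv, D]
    outputs = []
    num_seqs = meta.cu_seqlens_q.numel() - 1
    for s in range(num_seqs):
        q_start = meta.cu_seqlens_q[s].item()
        q_end   = meta.cu_seqlens_q[s+1].item()
        q_len   = q_end - q_start
        kv_len  = meta.seq_lens_kv[s].item()

        # Gather this sequence's blocks, then trim the partially filled last block.
        n_blocks  = (kv_len + meta.block_size - 1) // meta.block_size
        block_ids = meta.block_table[s, :n_blocks]
        K_seq = K_cache[block_ids].reshape(-1, H_kv, K_cache.shape[-1])[:kv_len]
        V_seq = V_cache[block_ids].reshape(-1, H_kv, V_cache.shape[-1])[:kv_len]

        if H_q != H_kv:  # GQA: query head h reads KV head h // (H_q // H_kv)
            K_seq = K_seq.repeat_interleave(H_q // H_kv, dim=1)
            V_seq = V_seq.repeat_interleave(H_q // H_kv, dim=1)

        Q_s = q[q_start:q_end].transpose(0,1).unsqueeze(0)   # [1, H, q_len, D]
        K_s = K_seq.transpose(0,1).unsqueeze(0)               # [1, H, kv_len, D]
        V_s = V_seq.transpose(0,1).unsqueeze(0)
        if q_len == kv_len or q_len == 1:
            # Full prefill (causal) or decode (the single query sees every cached key).
            out_s = F.scaled_dot_product_attention(
                Q_s, K_s, V_s,
                is_causal=bool(q_len == kv_len),
                dropout_p=0.0,
            )
        else:
            # Prefill chunk on top of existing cache: SDPA's is_causal aligns top-left,
            # so build the bottom-right mask (query i sees keys 0 .. kv_len - q_len + i).
            mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device)
            mask = mask.tril(diagonal=kv_len - q_len)
            out_s = F.scaled_dot_product_attention(Q_s, K_s, V_s, attn_mask=mask, dropout_p=0.0)
        outputs.append(out_s.squeeze(0).transpose(0,1))      # [q_len, H, D]
    return torch.cat(outputs, dim=0)                          # [T_total_q, H, D]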
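PyTorch's scaled_dot_product_attention with is_causal=True aligns the causal triangle to the top-left corner when the query is shorter than the key set. That matches the rule in 3.2 only for full prefill. The sketch below checks this on random CPU tensors with arbitrary sizes. A prefill chunk over a cached prefix reproduces the matching rows of a full causal prefill only with an explicit bottom-right-aligned boolean mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, D, prefix, chunk = 2, 8, 5, 3
kv_len = prefix + chunk
q = torch.randn(1, H, kv_len, D)
k = torch.randn(1, H, kv_len, D)
v = torch.randn(1, H, kv_len, D)

# Reference: one causal prefill over all kv_len tokens.
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Chunked prefill: only the last `chunk` queries, against all kv_len keys.
q_chunk = q[:, :, prefix:]
# Bottom-right-aligned mask (True = may attend): query i sees keys 0 .. kv_len - chunk + i.
mask = torch.ones(chunk, kv_len, dtype=torch.bool).tril(diagonal=kv_len - chunk)
bottom_right = F.scaled_dot_product_attention(q_chunk, k, v, attn_mask=mask)
top_left = F.scaled_dot_product_attention(q_chunk, k, v, is_causal=True)

print(torch.allclose(bottom_right, full[:, :, prefix:], atol=1e-5))  # True
print(torch.allclose(top_left, full[:, :, prefix:], atol=1e-5))      # False
```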

Path B — flash-attn (recommended once installed)

import math
from flash_attn import flash_attn_with_kvcache

D = q.shape[-1]                       # head dim
out = flash_attn_with_kvcache(
    q,                                # [num_seqs, q_len_max, H_q, D] (padded, not packed)
    pool.K[layer], pool.V[layer],     # [num_blocks, block_size, H_kv, D]
    cache_seqlens = meta.seq_lens_kv, # [num_seqs] int32, length after set_kv wrote this step
    block_table   = meta.block_table, # [num_seqs, max_blocks] int32
    causal        = True,             # bottom-right aligned when q_len < kv_len
    softmax_scale = 1.0 / math.sqrt(D),
)

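Single fused kernel call for the whole batch. Note that the query layout differs from Path A. flash_attn_with_kvcache takes a padded [num_seqs, q_len_max, H_q, D] query instead of a packed [T, H_q, D] plus cu_seqlens_q. For the single-sequence smoke test, unsqueeze(0) the query and squeeze(0) the output. Validate with Path A first, then swap.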

3.5 Pitfalls

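  1. block_table dtype. Must be int32 for flash-attn (so must cache_seqlens); int64 raises. flash-attn 2.x has also required a paged cache's page size (block_size here) to be a multiple of 256. Check the flash_attn_with_kvcache docstring of your installed version before pairing it with a small allocator block size.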
  2. Tensor scalars where Python ints are needed. cu_seqlens_q[s] and seq_lens_kv[s] are 0-D tensors; call .item() before slicing or arithmetic, and wrap q_len == kv_len in bool(...) for SDPA's is_causal.
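  3. Forgetting the .view() in set_kv. Without it, slot_mapping indexes the block dim, so writes land in whole blocks (see 3.3).
  4. is_causal generalisation. The per-sequence form is q_len_i == kv_len_i, the natural extension of L1's T == kv_valid_len. It is exact for full prefill (q_len_i == kv_len_i) and for decode (q_len_i == 1, where no mask is needed). A prefill chunk on top of existing cache (1 < q_len_i < kv_len_i) needs a bottom-right-aligned mask, because SDPA's is_causal aligns its triangle top-left when the query is shorter than the keys.
  5. Output shape after the loop. attn_paged_torch returns [T, H, D] (token-major). To feed o_proj, just view(T, H*D), with no extra transpose. (Real bug from this lesson, see 3.7.)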
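Pitfall 1 and several metadata shape mistakes can be caught before any kernel runs. check_paged_meta below is a hypothetical helper, not part of flash-attn or the lesson code. It asserts the dtypes and the cu_seqlens_q prefix-sum shape. It also checks that every block_table row covers ceil(kv_len / block_size) valid block ids.

```python
import torch

def check_paged_meta(cu_seqlens_q, seq_lens_kv, block_table, block_size, num_blocks):
    """Hypothetical pre-flight checks for the three paged-attention metadata tensors."""
    num_seqs = seq_lens_kv.numel()
    assert block_table.dtype == torch.int32, "block_table must be int32 for flash-attn"
    assert seq_lens_kv.dtype == torch.int32, "seq_lens_kv (cache_seqlens) must be int32"
    assert cu_seqlens_q.numel() == num_seqs + 1 and int(cu_seqlens_q[0]) == 0
    q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    assert bool((q_lens >= 1).all()), "every sequence needs at least one query token"
    assert bool((q_lens <= seq_lens_kv).all()), "new tokens must already be counted in kv_len"
    needed = (seq_lens_kv + block_size - 1) // block_size     # ceil(kv_len / block_size)
    assert block_table.shape[0] == num_seqs
    assert bool((needed <= block_table.shape[1]).all()), "block_table row too short"
    for s in range(num_seqs):
        ids = block_table[s, : int(needed[s])]
        assert bool(((ids >= 0) & (ids < num_blocks)).all()), f"seq {s}: block id out of range"

# The mixed example from 3.2 with block_size = 4 passes silently.
cu = torch.tensor([0, 1, 6], dtype=torch.int32)
kv = torch.tensor([9, 5], dtype=torch.int32)
bt = torch.tensor([[3, 7, 8], [1, 2, 0]], dtype=torch.int32)
check_paged_meta(cu, kv, bt, block_size=4, num_blocks=16)
```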

3.6 Smoke harness changes

# prefill
N    = prompt_ids.numel()
reserve(req, alloc, N)
slot = torch.tensor(req.slot_indices[:N], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.arange(N, device='cuda'),
    slot_mapping = slot,
    cu_seqlens_q = torch.tensor([0, N], device='cuda', dtype=torch.int32),
    seq_lens_kv  = torch.tensor([N],    device='cuda', dtype=torch.int32),
    block_table  = torch.tensor([req.blocks], device='cuda', dtype=torch.int32),
    block_size   = alloc.block_size,
    is_prefill   = True,
)
logits = model(prompt_ids, pool, meta)

# decode step i
reserve(req, alloc, req.cur_len + 1)
pos = req.cur_len
slot_new = torch.tensor([req.slot_indices[pos]], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions    = torch.tensor([pos], device='cuda'),
    slot_mapping = slot_new,
    cu_seqlens_q = torch.tensor([0, 1],     device='cuda', dtype=torch.int32),
    seq_lens_kv  = torch.tensor([pos + 1],  device='cuda', dtype=torch.int32),
    block_table  = torch.tensor([req.blocks], device='cuda', dtype=torch.int32),
    block_size   = alloc.block_size,
    is_prefill   = False,
)

3.7 Q&A from session

Student question: Show me the shape with diagrams. I'm not seeing why .view(-1, H_kv, D) is right.

See 3.3 above. Two diagrams: (1) why fancy indexing on the un-viewed tensor selects whole blocks, and (2) the same bytes presented as either [N*B, H, D] for writes or [N, B, H, D] for reads.

Student question: L3 runs end-to-end but the token output is garbage ("OwO, the user is asking…"). What's wrong?

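The diagnostic: if the output diverges from L1's at position 0 (the very first generated token), prefill itself is wrong, not decode. Numeric drift looks different: the first tokens agree ("Paris" → "Paris") and only later does some other plausible word appear. A total non-sequitur from the first token points to scrambled inputs to some layer.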

Cause: an extra transpose(0, 1) before the flatten in Qwen3Attention.forward:

# BUG
out = out.transpose(0, 1).contiguous().view(T, H_q * D)
# FIX (out is already [T, H, D] from attn_paged_torch)
out = out.contiguous().view(T, H_q * D)

The bad path: [T, H, D] → transpose → [H, T, D] → contiguous (memory is now head-major) → view(T, H*D) reinterprets the head-major buffer as token-major. Each row fed to o_proj is then H consecutive D-sized chunks of that buffer, mostly one head's activations across several tokens, instead of one token's H heads. The scramble compounds across 36 layers → "OwO".
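The scramble is easy to reproduce on a toy tensor whose values encode their origin. In this sketch, out[t, h, d] = t*H*D + h*D + d, so every row of the flattened result shows which (token, head) its numbers came from.

```python
import torch

T, H, D = 3, 2, 4
out = torch.arange(T * H * D).view(T, H, D)            # token-major [T, H, D], like attn_paged_torch

good = out.contiguous().view(T, H * D)                 # row t = token t's heads, concatenated
bad = out.transpose(0, 1).contiguous().view(T, H * D)  # head-major buffer read as token-major

print(good[0].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]    -> token 0, heads 0 and 1
print(bad[0].tolist())   # [0, 1, 2, 3, 8, 9, 10, 11]  -> head 0 of tokens 0 and 1
```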

Student question: Why have both Path A and Path B? Just pick one.

Path A is the spec the kernel implements. If your metadata is wrong, both paths will be wrong. With Path A, though, the bug sits in your own loop, where every intermediate is easy to print. With Path B, it sits inside a kernel, which is much harder to inspect. Build Path A first, verify the smoke test passes, then swap the attention call for Path B.

3.8 Acceptance

L3 pass criteria: scripts/l3_smoke.py emits the same 20 token IDs as L1 / L2. Optionally, Path B (flash-attn) produces identical output.

3.9 What L3 unlocks

The metadata you just built (cu_seqlens_q, seq_lens_kv, block_table) is exactly what L5's scheduler will produce per step from a list of in-flight requests. That's the entire scheduler in one sentence. The hard data-structure work is done; L4 (sampler) and L5 (scheduler) become small surgical additions on top of this skeleton.
