L3 — Paged attention kernel
3.1 Why L2 hits a wall
L2's attention reads its history with a standalone gather, a separate op that materializes a copy of the whole history before attention runs:

```python
K_full = pool.K[layer][meta.kv_indices]  # gather, shape [T_kv, H_kv, D]
```
Two problems:
- HBM traffic and intermediate allocation. For Qwen3-8B at 1k context, that's roughly 1024 × 8 × 128 × 2B × 36 × 2 ≈ 144 MiB shuffled per step (tokens × KV heads × head dim × bytes per element × layers × {K, V}). Linear in batch size and context.
- Can't represent multiple sequences in one kernel call. A flat `kv_indices` can't say "query q attends only to keys 3..14 of its sequence". You'd need expensive boolean masks.
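The 144 MiB figure is just the product of the gathered tensor's dimensions, counted once for K and once for V. A small helper (hypothetical, for illustration only; defaults assume Qwen3-8B's 8 KV heads, head_dim 128, 36 layers, 2-byte elements) makes the linear scaling in batch and context explicit:

```python
def gather_mib_per_step(ctx_len: int, batch: int = 1, *, num_kv_heads: int = 8,
                        head_dim: int = 128, bytes_per_elem: int = 2,
                        num_layers: int = 36) -> float:
    """MiB materialized by an L2-style gather of K and V across all layers."""
    per_tensor = batch * ctx_len * num_kv_heads * head_dim * bytes_per_elem * num_layers
    return 2 * per_tensor / 2**20  # x2: both K and V are gathered


print(gather_mib_per_step(1024))            # 144.0
print(gather_mib_per_step(1024, batch=8))   # 1152.0
print(gather_mib_per_step(4096, batch=8))   # 4608.0
```

Eight sequences at 4k context already means several GiB of copies per decode step, none of which a paged kernel needs.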
The fix: keep storage paged but pass the kernel structural metadata describing where each sequence lives, then call the kernel once for the whole batch.
3.2 The three metadata tensors
Memorize these: most paged attention kernels take some variant of them (flash-attn's `flash_attn_with_kvcache`, used in Path B below, takes the last two).
```
cu_seqlens_q : [num_seqs + 1]                  prefix-sum of query lengths (starts at 0)
seq_lens_kv  : [num_seqs]                      total cached length per sequence, including this step's new tokens
block_table  : [num_seqs, max_blocks_per_seq]  per-seq page table (right-padded)
```
Examples:
| scenario | cu_seqlens_q | seq_lens_kv | block_table |
|---|---|---|---|
| single-seq prefill of N | [0, N] | [N] | [[b0, b1, …]] |
| single-seq decode at step i | [0, 1] | [N+i+1] | [[b0, b1, …]] |
| two-seq mixed (A: decode, B: prefill chunk) | [0, 1, 1+chunk_B] | [len_A, len_B] | [[A_blocks…], [B_blocks…]] |
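A minimal sketch of how the three tensors fall out of per-request state. It assumes each in-flight request is described by a hypothetical `(q_len, kv_len, blocks)` tuple and a block size of 16; padding the block table with 0 is also an assumption (the paths in this lesson only read the first `ceil(kv_len / block_size)` entries of each row):

```python
import torch


def build_paged_meta(reqs, device="cpu"):
    """reqs: list of (q_len, kv_len, blocks) per sequence (hypothetical format)."""
    q_lens = torch.tensor([q for q, _, _ in reqs], dtype=torch.int32)
    cu_seqlens_q = torch.zeros(len(reqs) + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)  # prefix sum, starts at 0
    seq_lens_kv = torch.tensor([kv for _, kv, _ in reqs], dtype=torch.int32)
    max_blocks = max(len(b) for _, _, b in reqs)
    # Right-pad every row to max_blocks; padded entries are never read.
    block_table = torch.zeros(len(reqs), max_blocks, dtype=torch.int32)
    for i, (_, _, blocks) in enumerate(reqs):
        block_table[i, : len(blocks)] = torch.tensor(blocks, dtype=torch.int32)
    return cu_seqlens_q.to(device), seq_lens_kv.to(device), block_table.to(device)


# Last row of the table: A decodes 1 token (37 cached incl. it),
# B prefills a 5-token chunk (21 cached incl. the chunk).
cu, lens, bt = build_paged_meta([(1, 37, [4, 9, 2]), (5, 21, [7, 3])])
print(cu.tolist())    # [0, 1, 6]
print(lens.tolist())  # [37, 21]
print(bt.tolist())    # [[4, 9, 2], [7, 3, 0]]
```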
Query token at position q (0-based) within sequence s, where `q_len[s] = cu_seqlens_q[s+1] − cu_seqlens_q[s]`, attends to keys 0 .. `seq_lens_kv[s] − q_len[s] + q` (inclusive) of sequence s only. Inter-sequence attention is structurally impossible, which is exactly what we want for batching.
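Per sequence, that rule is an ordinary causal mask shifted right by `kv_len − q_len`, the number of keys cached before this step's queries. A sketch that builds it with `torch.tril`'s `diagonal` offset (the function name is hypothetical):

```python
import torch


def paged_causal_mask(q_len: int, kv_len: int) -> torch.Tensor:
    """Bool [q_len, kv_len]; True where query q may attend to key k.

    Query q (0-based within its sequence) sees keys 0 .. kv_len - q_len + q inclusive.
    """
    offset = kv_len - q_len  # keys cached before this step's queries
    return torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=offset)


print(paged_causal_mask(3, 3).int().tolist())  # prefill: plain causal
# [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
print(paged_causal_mask(1, 5).int().tolist())  # decode: sees every cached key
# [[1, 1, 1, 1, 1]]
print(paged_causal_mask(2, 5).int().tolist())  # 2-token chunk after 3 cached keys
# [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```

Note the mask is aligned to the bottom-right corner; the top-left-aligned `tril(diagonal=0)` is only correct when `q_len == kv_len`.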
3.3 Storage layout switch — block-tensor
Reshape each layer's K, V from [num_slots, H_kv, D] (token-flat) to [num_blocks, block_size, H_kv, D] (block-tensor). Same bytes, different addressing. The block dim is what kernels expect for paged reads.
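A quick way to convince yourself the two layouts share storage: view a block-tensor as token-flat and check that flat slot `block_id * block_size + within` lands on `[block_id, within]`. Sizes are arbitrary toy values:

```python
import torch

num_blocks, block_size, H_kv, D = 4, 16, 2, 8
K_blocks = torch.randn(num_blocks, block_size, H_kv, D)  # block-tensor layout
K_flat = K_blocks.view(-1, H_kv, D)                       # token-flat view, no copy

assert K_flat.data_ptr() == K_blocks.data_ptr()           # same storage
block_id, within = 2, 5
flat = block_id * block_size + within
assert torch.equal(K_flat[flat], K_blocks[block_id, within])

K_flat[flat] = 0.0                                        # write through the flat view...
assert K_blocks[block_id, within].abs().sum() == 0        # ...shows up in the block layout
```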
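The view trick that makes set_kv work
After the switch, `self.K[layer_id]` has shape `[num_blocks, block_size, H_kv, D]`, so indexing it directly with `slot_mapping` indexes the block dimension: each slot ID selects a whole block. The write then either fails (slot IDs beyond `num_blocks`, or shapes that don't broadcast) or silently broadcasts one token's K/V across an entire, wrong block. Fix with a zero-copy reshape: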
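```python
def set_kv(self, layer_id, slot_mapping, k, v):
    # Zero-copy view of the [num_blocks, block_size, H_kv, D] cache as token-flat slots
    K = self.K[layer_id].view(-1, self.num_kv_heads, self.head_dim)  # [N*B, H_kv, D]
    V = self.V[layer_id].view(-1, self.num_kv_heads, self.head_dim)
    # slot_mapping: [T] flat slot IDs; k, v: [T, H_kv, D]
    K[slot_mapping] = k
    V[slot_mapping] = v
```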
The flat coordinate flat = block_id × block_size + within is exactly what BlockAllocator already produces in req.slot_indices. No translation math, just a view to match.
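The failure mode the view avoids, on toy sizes: indexing the un-viewed block-tensor with slot IDs selects along the block dimension, so you read (or overwrite) whole blocks. The `slot_mapping` values are made up for illustration:

```python
import torch

num_blocks, block_size, H_kv, D = 8, 4, 2, 3
K = torch.zeros(num_blocks, block_size, H_kv, D)
slot_mapping = torch.tensor([5, 6])  # two tokens: block 1, offsets 1 and 2
k_new = torch.ones(2, H_kv, D)

# Wrong: dim 0 of K is the block index, so this picks blocks 5 and 6 whole.
print(K[slot_mapping].shape)         # torch.Size([2, 4, 2, 3])

# Right: view as [num_blocks * block_size, H_kv, D]; dim 0 is now the slot index.
K.view(-1, H_kv, D)[slot_mapping] = k_new
print(K[1, :, 0, 0].tolist())        # [0.0, 1.0, 1.0, 0.0]
print(K.sum().item())                # 12.0: 2 tokens x H_kv x D ones, nothing else touched
```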
3.4 Two implementation paths
Path A — Pure PyTorch reference (no install)
For each sequence, do per-seq gather → SDPA. Same correctness as a real kernel, just no fusion. Use this to validate your metadata is shaped right.
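```python
import torch
import torch.nn.functional as F


def attn_paged_torch(q, K_cache, V_cache, meta, *, H_q, H_kv):
    """Reference paged attention: per-sequence gather + SDPA.

    q:       [T_total_q, H_q, D] token-major queries for the whole batch
    K_cache: [num_blocks, block_size, H_kv, D] (V_cache likewise)
    """
    outputs = []
    num_seqs = meta.cu_seqlens_q.numel() - 1
    for s in range(num_seqs):
        # .item() turns 0-D tensors into Python ints for slicing and arithmetic
        q_start = meta.cu_seqlens_q[s].item()
        q_end = meta.cu_seqlens_q[s + 1].item()
        q_len = q_end - q_start
        kv_len = meta.seq_lens_kv[s].item()

        # Gather only this sequence's blocks, flatten to tokens, drop the tail padding
        n_blocks = (kv_len + meta.block_size - 1) // meta.block_size
        block_ids = meta.block_table[s, :n_blocks].long()
        K_seq = K_cache[block_ids].reshape(-1, H_kv, K_cache.shape[-1])[:kv_len]
        V_seq = V_cache[block_ids].reshape(-1, H_kv, V_cache.shape[-1])[:kv_len]

        # GQA: each KV head serves H_q // H_kv consecutive query heads
        if H_q != H_kv:
            K_seq = K_seq.repeat_interleave(H_q // H_kv, dim=1)
            V_seq = V_seq.repeat_interleave(H_q // H_kv, dim=1)

        Q_s = q[q_start:q_end].transpose(0, 1).unsqueeze(0)  # [1, H, q_len, D]
        K_s = K_seq.transpose(0, 1).unsqueeze(0)             # [1, H, kv_len, D]
        V_s = V_seq.transpose(0, 1).unsqueeze(0)

        if q_len == kv_len or q_len == 1:
            # Full prefill (plain causal) or decode (sees every cached key)
            out_s = F.scaled_dot_product_attention(
                Q_s, K_s, V_s, is_causal=bool(q_len == kv_len), dropout_p=0.0,
            )
        else:
            # Chunk on top of cached tokens: SDPA's is_causal is top-left aligned,
            # so apply the 3.2 rule (query q sees keys 0 .. kv_len - q_len + q) explicitly
            mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device)
            mask = mask.tril(diagonal=kv_len - q_len)
            out_s = F.scaled_dot_product_attention(
                Q_s, K_s, V_s, attn_mask=mask, dropout_p=0.0,
            )
        outputs.append(out_s.squeeze(0).transpose(0, 1))  # [q_len, H, D]
    return torch.cat(outputs, dim=0)                       # [T_total_q, H, D]
```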
Path B — flash-attn (recommended once installed)
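```python
import math
from flash_attn import flash_attn_with_kvcache

D = q.shape[-1]
out = flash_attn_with_kvcache(
    q,                               # [num_seqs, q_len_max, H_q, D]; for pure decode, [num_seqs, H_q, D].unsqueeze(1)
    pool.K[layer], pool.V[layer],    # [num_blocks, block_size, H_kv, D]; GQA handled natively (H_q % H_kv == 0)
    cache_seqlens=meta.seq_lens_kv,  # [num_seqs] int32
    block_table=meta.block_table,    # [num_seqs, max_blocks] int32; flash-attn 2.x docstring: block_size must be a multiple of 256
    causal=True,                     # bottom-right aligned when q_len < kv_len, matching the 3.2 rule
    softmax_scale=1.0 / math.sqrt(D),
)                                    # out: [num_seqs, q_len_max, H_q, D]
```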
Single fused kernel call for the whole batch. Use Path A first to validate, then swap.
3.5 Pitfalls
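- `block_table` dtype. Must be `int32` for flash-attn; `int64` raises.
- Tensor scalars where Python ints are needed. `cu_seqlens_q[s]` and `seq_lens_kv[s]` are 0-D tensors; call `.item()` before slicing or arithmetic, and make sure SDPA's `is_causal` receives a Python `bool` (hence `bool(q_len == kv_len)`).
- Forgetting the `.view()` in `set_kv`. Without it, `K[slot_mapping]` indexes the block dim, so writes hit whole blocks instead of single slots (see 3.3).
- `is_causal` generalisation. The per-sequence form is `q_len_i == kv_len_i`, the natural extension of L1's `T == kv_valid_len`. Decode (`q_len_i == 1`) needs no mask at all; a prefill chunk on top of cached tokens needs the offset rule from 3.2, which `is_causal` cannot express.
- Output shape after the loop. `attn_paged_torch` returns `[T, H, D]` (token-major). To feed `o_proj`, just `view(T, H*D)`; no extra transpose. (Real bug from this lesson; see 3.7.)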
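Several of these pitfalls can be caught before any kernel runs. A hedged sketch of a hypothetical `check_paged_meta` helper, assuming `meta` exposes the fields used in this lesson (`cu_seqlens_q`, `seq_lens_kv`, `block_table`, `block_size`) and that new tokens are written to the cache before attention, as `set_kv` does here:

```python
import torch
from types import SimpleNamespace


def check_paged_meta(meta) -> None:
    """Fail fast on malformed paged-attention metadata (hypothetical helper)."""
    num_seqs = meta.seq_lens_kv.numel()
    # Keep all three int32 so the same meta can also be handed to flash-attn.
    for name in ("cu_seqlens_q", "seq_lens_kv", "block_table"):
        t = getattr(meta, name)
        assert t.dtype == torch.int32, f"{name} must be int32, got {t.dtype}"
    assert meta.cu_seqlens_q.shape == (num_seqs + 1,), "cu_seqlens_q needs num_seqs + 1 entries"
    assert meta.cu_seqlens_q[0].item() == 0, "cu_seqlens_q must start at 0"
    assert meta.block_table.shape[0] == num_seqs, "need one block_table row per sequence"
    # This step's queries are already in the cache, so q_len <= kv_len per sequence.
    q_lens = meta.cu_seqlens_q[1:] - meta.cu_seqlens_q[:-1]
    assert bool((q_lens <= meta.seq_lens_kv).all()), "q_len exceeds kv_len for some sequence"
    needed = (meta.seq_lens_kv + meta.block_size - 1) // meta.block_size
    assert int(needed.max().item()) <= meta.block_table.shape[1], "block_table is too narrow"


meta = SimpleNamespace(
    cu_seqlens_q=torch.tensor([0, 1, 6], dtype=torch.int32),
    seq_lens_kv=torch.tensor([37, 21], dtype=torch.int32),
    block_table=torch.tensor([[4, 9, 2], [7, 3, 0]], dtype=torch.int32),
    block_size=16,
)
check_paged_meta(meta)  # passes silently
```

Calling it once per step in the smoke harness costs a few host syncs, which is fine while debugging; drop it once Path B is in place.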
3.6 Smoke harness changes
```python
# prefill
N = prompt_ids.numel()
reserve(req, alloc, N)
slot = torch.tensor(req.slot_indices[:N], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions=torch.arange(N, device='cuda'),
    slot_mapping=slot,
    cu_seqlens_q=torch.tensor([0, N], device='cuda', dtype=torch.int32),
    seq_lens_kv=torch.tensor([N], device='cuda', dtype=torch.int32),
    block_table=torch.tensor([req.blocks], device='cuda', dtype=torch.int32),
    block_size=alloc.block_size,
    is_prefill=True,
)
logits = model(prompt_ids, pool, meta)

# decode step i
reserve(req, alloc, req.cur_len + 1)
pos = req.cur_len  # position of the token fed this step
slot_new = torch.tensor([req.slot_indices[pos]], device='cuda', dtype=torch.long)
meta = ForwardMeta(
    positions=torch.tensor([pos], device='cuda'),
    slot_mapping=slot_new,
    cu_seqlens_q=torch.tensor([0, 1], device='cuda', dtype=torch.int32),
    seq_lens_kv=torch.tensor([pos + 1], device='cuda', dtype=torch.int32),  # includes this step's token
    block_table=torch.tensor([req.blocks], device='cuda', dtype=torch.int32),
    block_size=alloc.block_size,
    is_prefill=False,
)
```
3.7 Q&A from session
Student question: Show me the shape with diagrams; I'm not seeing why `.view(-1, H_kv, D)` is right.
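See 3.3 above. Two diagrams: (1) why fancy indexing on the un-viewed tensor selects whole blocks, and (2) the same bytes presented as either `[N*B, H, D]` for writes or `[N, B, H, D]` for reads. In short: `K[idx]` with a single index tensor always indexes dim 0, so per-token writes need a view whose dim 0 is the slot, and per-block reads need a view whose dim 0 is the block.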
Student question: L3 runs end-to-end but token output is garbage ("OwO, the user is asking…"). What's wrong?
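The diagnostic: divergence at position 0 means prefill itself is wrong, not decode. Small numeric drift usually looks like agreement for a few tokens followed by a plausible alternative ("Paris" → "Paris" → some other word); a total non-sequitur from the very first token points to a scrambled layer input.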
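Finding the first position where two runs disagree is the quickest way to apply this diagnostic. A tiny helper (hypothetical, not part of the lesson's scripts) for comparing L3's token IDs against the L1/L2 reference; the IDs in the usage lines are made up:

```python
from typing import Optional, Sequence


def first_divergence(ref_ids: Sequence[int], test_ids: Sequence[int]) -> Optional[int]:
    """First index where the two token-ID sequences differ, or None if identical."""
    for i, (a, b) in enumerate(zip(ref_ids, test_ids)):
        if a != b:
            return i
    if len(ref_ids) != len(test_ids):
        return min(len(ref_ids), len(test_ids))  # one run stopped early
    return None


# 0 -> suspect prefill; a later index after a matching prefix -> suspect decode or drift
print(first_divergence([791, 6864, 374], [46, 54, 78]))     # 0
print(first_divergence([791, 6864, 374], [791, 6864, 12]))  # 2
print(first_divergence([1, 2, 3], [1, 2, 3]))               # None
```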
Cause: an extra `transpose(0, 1)` before the flatten in `Qwen3Attention.forward`:

```python
# BUG
out = out.transpose(0, 1).contiguous().view(T, H_q * D)

# FIX (out is already [T, H, D] from attn_paged_torch)
out = out.contiguous().view(T, H_q * D)
```
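The bad path: shape `[T, H, D]` → `transpose(0, 1)` → `[H, T, D]` → `.contiguous()` (memory now head-major) → `view(T, H*D)` reinterprets the head-major buffer as token-major. Each "row" fed to `o_proj` is then H·D consecutive head-major values, i.e. H token vectors taken head by head (mostly different tokens of the same head), instead of one token's H heads. For T = 1 the transpose doesn't change the memory order, so only prefill is scrambled, which matches the divergence at position 0. The error compounds across 36 layers → "OwO".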
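The scramble is easy to reproduce on toy tensors (H = 4, D = 2 are arbitrary). With T > 1 the extra transpose changes what lands in each row; with T = 1 it is a no-op, which is why this bug surfaces in prefill rather than in decode:

```python
import torch

H, D = 4, 2  # toy head count and head dim


def flatten_buggy(out: torch.Tensor) -> torch.Tensor:
    T = out.shape[0]
    return out.transpose(0, 1).contiguous().view(T, H * D)  # head-major buffer read as token-major


def flatten_fixed(out: torch.Tensor) -> torch.Tensor:
    T = out.shape[0]
    return out.contiguous().view(T, H * D)


for T in (1, 3):
    # Token-major [T, H, D] like attn_paged_torch's output; value = t*H*D + h*D + d
    out = torch.arange(T * H * D).view(T, H, D)
    print(T, torch.equal(flatten_buggy(out), flatten_fixed(out)))
# 1 True
# 3 False

print(flatten_buggy(torch.arange(24).view(3, H, D))[0].tolist())
# [0, 1, 8, 9, 16, 17, 2, 3]: head 0 of tokens 0, 1, 2, then head 1 of token 0
```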
Student question: Why have both Path A and Path B? Just pick one.
Path A is an executable spec of what the kernel computes. If your metadata is wrong, both paths will be wrong, but Path A's bug lives in your own loop, where you can print every intermediate; Path B's bug surfaces inside a fused kernel, which is much harder to inspect. Build Path A first, verify the smoke test passes, then swap to Path B as a one-line substitution.
3.8 Acceptance
L3 pass criteria
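`scripts/l3_smoke.py` emits the same 20 token IDs as L1 / L2. Optionally, Path B (flash-attn) emits the same token IDs too; its logits can differ slightly from Path A's because the kernels accumulate in a different order, so compare tokens rather than bit-exact activations.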
3.9 What L3 unlocks
The metadata you just built (cu_seqlens_q, seq_lens_kv, block_table) is exactly what L5's scheduler will produce per step from a list of in-flight requests. That's the entire scheduler in one sentence. The hard data-structure work is done; L4 (sampler) and L5 (scheduler) become small surgical additions on top of this skeleton.