L5 — Scheduler / continuous batching
5.1 The problem
Three requests arrive on a busy server:
t=0 req A arrives, prompt = 8 tokens
t=10 req B arrives, prompt = 32 tokens (A is mid-decode)
t=15 req C arrives, prompt = 5 tokens (A and B both mid-decode)
A naïve "one request per forward call" server processes A fully, then B, then C. Tail latency is awful, and the GPU sits mostly idle during single-token decodes (Qwen3-8B's decode kernel utilises maybe 5–15% of an H100's compute).
Continuous batching: every forward step, ask "who's runnable?" and pack them all into a single call. The model doesn't change. The KV pool doesn't change. What changes is the bookkeeping that builds ForwardMeta — instead of one request's worth, it's whatever's runnable, mixed.
step 0: prefill A (8 q tokens) + nothing else
step 1: decode A (1) + prefill B (32 q tokens)
step 2: decode A (1) + decode B (1) + prefill C (5 q tokens)
step 3: decode A (1) + decode B (1) + decode C (1)
…
The kernel call shape goes from "one sequence at a time" to "packed batch of mixed q_lens" — which is exactly why we built cu_seqlens_q / seq_lens_kv / block_table in L3. L5 is the consumer that justifies that design.
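To make the packed shape concrete, here is a minimal sketch that turns the per-step q_lens from the schedule above into cu_seqlens_q prefix sums. The helper name cu_seqlens is hypothetical; only the q_lens values come from the schedule.

```python
from itertools import accumulate

def cu_seqlens(q_lens):
    """Prefix sums with a leading 0: sequence i owns rows [cu[i], cu[i+1])."""
    return [0, *accumulate(q_lens)]

# q_lens per step, copied from the schedule above (A=8, B=32, C=5 prompt tokens).
steps = [
    [8],        # step 0: prefill A
    [1, 32],    # step 1: decode A + prefill B
    [1, 1, 5],  # step 2: decode A + decode B + prefill C
    [1, 1, 1],  # step 3: three decodes
]
packed = [cu_seqlens(q) for q in steps]
for i, cu in enumerate(packed):
    print(f"step {i}: cu_seqlens_q={cu}, T_total={cu[-1]}")
```

The last entry of each prefix-sum list is T_total, the number of query rows the model sees in that single forward call.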
5.2 The scheduler's public API
Three methods. This is roughly the surface most modern engines expose to the server layer.
from __future__ import annotations  # lets step() name StepResult before it is defined
from dataclasses import dataclass

class Scheduler:
    def add_request(self, req: Request) -> None: ...
    def step(self) -> StepResult: ...      # one model forward
    def has_unfinished(self) -> bool: ...

@dataclass
class StepResult:
    new_tokens: dict[int, int]   # req_id → token sampled this step
    finished: list[int]          # req_ids that completed this step
The server loop is then:
while scheduler.has_unfinished() or pending:
    while pending:
        scheduler.add_request(pending.popleft())
    out = scheduler.step()
    for rid, tok in out.new_tokens.items():
        stream_to_client(rid, tok)
    for rid in out.finished:
        close_stream(rid)
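The loop above can be exercised end to end without a model. The following is a toy sketch, not the engine's scheduler: ToyScheduler, the Request fields, and the fake "token = 100 + position" sampler are all assumptions made for illustration, and admission has no budget here.

```python
from collections import deque
from dataclasses import dataclass, field

@dataclass
class Request:                 # hypothetical minimal request
    id: int
    prompt_len: int
    max_tokens: int
    output_ids: list = field(default_factory=list)

@dataclass
class StepResult:
    new_tokens: dict
    finished: list

class ToyScheduler:
    def __init__(self):
        self.waiting = deque()   # no KV yet
        self.running = []        # has KV, decodes each step

    def add_request(self, req):
        self.waiting.append(req)

    def has_unfinished(self):
        return bool(self.waiting or self.running)

    def step(self):
        # Admit everything that is waiting (no budget in this toy), then
        # pretend to run one forward pass over the mixed batch.
        batch = self.running + list(self.waiting)
        self.waiting.clear()
        self.running = []
        out = StepResult(new_tokens={}, finished=[])
        for r in batch:
            tok = 100 + len(r.output_ids)      # stand-in for sampling
            r.output_ids.append(tok)
            out.new_tokens[r.id] = tok
            if len(r.output_ids) >= r.max_tokens:
                out.finished.append(r.id)
            else:
                self.running.append(r)         # requeue for the next step
        return out

def serve(requests):
    pending, sched = deque(requests), ToyScheduler()
    streamed, closed = {}, []
    while sched.has_unfinished() or pending:
        while pending:
            sched.add_request(pending.popleft())
        out = sched.step()
        for rid, tok in out.new_tokens.items():
            streamed.setdefault(rid, []).append(tok)   # stream_to_client
        closed.extend(out.finished)                    # close_stream
    return streamed, closed

streamed, closed = serve([Request(0, 8, 3), Request(1, 32, 1), Request(2, 5, 2)])
print(streamed, closed)
```

Requests finish in the order 1, 2, 0 because each one stops as soon as it reaches its own max_tokens, while the others keep decoding in the same batch.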
5.3 The two queues
            ┌──────────────┐  select admits up to MAX_BATCH
new req ──► │   WAITING    │ ─────────────────────────────┐
            │ (no KV yet)  │                              │
            └──────────────┘                              ▼
                                                  ┌───────────────┐
EOS / max_tokens ◄─── writeback ◄──────────────── │    RUNNING    │
                                                  │ (has KV; will │
                                                  │ advance ≥ 1)  │
                                                  └───────────────┘
WAITING requests need to be prefilled. RUNNING requests need to be decoded. Each step mixes some of both.
5.4 Four scheduling decisions
A. Admission
def select(self):
    decode_batch = list(self.running)
    prefill_batch = []
    budget = MAX_TOKENS_PER_STEP - len(decode_batch)
    while self.waiting and budget > 0 and len(decode_batch)+len(prefill_batch) < MAX_BATCH:
        head = self.waiting[0]
        need = len(head.prompt_ids)
        if need > budget: break
        if self.alloc.num_free_tokens() < need + 1: break
        prefill_batch.append((self.waiting.popleft(), need))
        budget -= need
    return decode_batch, prefill_batch
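A standalone sketch of the same admission rule, with plain integers instead of request objects so it can be run in isolation. The two constants are assumed values, and admit is a hypothetical helper. One deliberate difference from select above: this sketch decrements a local free-token tally as it admits, so that two prompts admitted in the same step cannot both count on the same free slots.

```python
from collections import deque

MAX_TOKENS_PER_STEP = 40   # assumed value, for illustration only
MAX_BATCH = 4              # assumed value, for illustration only

def admit(num_running, waiting_prompt_lens, free_tokens):
    """Return the prompt lengths admitted this step, FIFO, head of line first."""
    waiting = deque(waiting_prompt_lens)
    budget = MAX_TOKENS_PER_STEP - num_running    # each decode costs one q token
    admitted = []
    while waiting and budget > 0 and num_running + len(admitted) < MAX_BATCH:
        need = waiting[0]
        if need > budget or free_tokens < need + 1:
            break                                 # head does not fit: stop, never skip ahead
        admitted.append(waiting.popleft())
        budget -= need
        free_tokens -= need + 1                   # local tally (an addition in this sketch)
    return admitted

fits = admit(num_running=2, waiting_prompt_lens=[32, 5, 10], free_tokens=100)
blocked = admit(num_running=2, waiting_prompt_lens=[39, 1], free_tokens=100)
print(fits, blocked)
```

The second call shows head-of-line blocking: the 39-token prompt exceeds the remaining budget of 38, so the 1-token prompt queued behind it is not admitted either.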
B. Reservation
For each request, reserve enough slots to cover cur_len + q_len in total (a decode adds 1, a prefill adds q_len). Pre-check all allocations before touching the model: a partial allocation leaves you in an inconsistent state.
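One way to get the all-or-nothing behaviour is to plan every reservation first and only then mutate the allocator. The sketch below assumes a block-granular allocator; plan_reservations, blocks_needed, and the (cur_len, q_len, blocks_owned) tuple layout are hypothetical names chosen for illustration.

```python
def blocks_needed(total_len, block_size):
    """Ceiling division: blocks required to hold total_len KV slots."""
    return -(-total_len // block_size)

def plan_reservations(batch, free_blocks, block_size):
    """batch holds (cur_len, q_len, blocks_owned) per request.

    Returns the extra blocks each request needs, or None when the step as a
    whole does not fit. Nothing is mutated, so a None leaves state untouched.
    """
    extra = []
    for cur_len, q_len, owned in batch:
        need = blocks_needed(cur_len + q_len, block_size) - owned
        extra.append(max(need, 0))
    if sum(extra) > free_blocks:
        return None
    return extra

batch = [(8, 1, 1), (0, 32, 0), (16, 1, 1)]   # decode, fresh prefill, decode crossing a block edge
plan_ok = plan_reservations(batch, free_blocks=3, block_size=16)
plan_fail = plan_reservations(batch, free_blocks=2, block_size=16)
print(plan_ok, plan_fail)
```

Only after the plan comes back non-None would the real allocator hand out blocks, which avoids the half-reserved state described above.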
C. Packing — the ForwardMeta builder
def build_meta(self, batch):
    positions, slot_mapping, cu_seqlens_q, seq_lens_kv = [], [], [0], []
    block_table_rows = []
    max_blocks = max(len(r.blocks) for r, _ in batch)
    for r, q_len in batch:
        for p in range(r.cur_len, r.cur_len + q_len):
            positions.append(p)
            slot_mapping.append(r.slot_indices[p])
        cu_seqlens_q.append(cu_seqlens_q[-1] + q_len)
        seq_lens_kv.append(r.cur_len + q_len)
        block_table_rows.append(r.blocks + [-1] * (max_blocks - len(r.blocks)))
    return ForwardMeta(
        positions    = torch.tensor(positions,        device='cuda', dtype=torch.long),
        slot_mapping = torch.tensor(slot_mapping,     device='cuda', dtype=torch.long),
        cu_seqlens_q = torch.tensor(cu_seqlens_q,     device='cuda', dtype=torch.int32),
        seq_lens_kv  = torch.tensor(seq_lens_kv,      device='cuda', dtype=torch.int32),
        block_table  = torch.tensor(block_table_rows, device='cuda', dtype=torch.int32),
        block_size   = self.alloc.block_size,
    )
is_prefill dies here. In a mixed batch, a single boolean has no global meaning. attn_paged_torch already handles per-seq causality via q_len == kv_len.
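As a worked example, the packing loop can be traced in plain Python for step 2 of the schedule in §5.1 (decode A, decode B, prefill C). This is a sketch under assumptions: the block ids, block_size = 16, and the slot formula blocks[p // block_size] * block_size + p % block_size are illustrative stand-ins for r.slot_indices, not values taken from the engine.

```python
def pack(batch, block_size):
    """batch holds (cur_len, q_len, blocks) per request; mirrors the builder's loop."""
    positions, slot_mapping, cu, kv, table = [], [], [0], [], []
    max_blocks = max(len(blocks) for _, _, blocks in batch)
    for cur_len, q_len, blocks in batch:
        for p in range(cur_len, cur_len + q_len):
            positions.append(p)
            # Assumed paged layout: logical position -> physical slot.
            slot_mapping.append(blocks[p // block_size] * block_size + p % block_size)
        cu.append(cu[-1] + q_len)
        kv.append(cur_len + q_len)
        table.append(blocks + [-1] * (max_blocks - len(blocks)))   # pad with an invalid id
    return positions, slot_mapping, cu, kv, table

step2 = [
    (9, 1, [0]),         # A: 8 prompt tokens + 1 decoded, now decoding position 9
    (32, 1, [1, 2, 3]),  # B: 32 prompt tokens prefilled, now decoding position 32
    (0, 5, [4]),         # C: fresh prefill of 5 tokens
]
positions, slot_mapping, cu_seqlens_q, seq_lens_kv, block_table = pack(step2, block_size=16)
print(positions, slot_mapping, cu_seqlens_q, seq_lens_kv, block_table)
```

Nothing in the result says "prefill" or "decode": a decode is just a sequence whose q_len is 1 and whose kv length is larger than that.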
D. Writeback
The model returns [T_total, vocab] logits. Pick the last row of each sequence:
last_rows = (meta.cu_seqlens_q[1:] - 1).long()         # [num_seqs]
per_seq = logits[last_rows]                            # [num_seqs, vocab]
next_ids = sampler(per_seq.float(), params, prev_ids)  # [num_seqs]
for (r, q_len), tok in zip(batch, next_ids.tolist()):
    r.output_ids.append(tok)
    r.cur_len += q_len                                 # NOT always +1
    if tok == eos_id or len(r.output_ids) >= r.max_tokens:
        finished.append(r.id)
    else:
        self.running.append(r)                         # requeue!
off-by-one trap
cur_len += q_len, NOT += 1. A prefill of 32 tokens advances cur_len from 0 to 32 in one step. A decode advances by 1.
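The difference is easy to see by tracing cur_len for request B (a 32-token prefill followed by two decodes). The advance helper below is hypothetical and exists only to contrast the two update rules.

```python
def advance(cur_len, q_lens, always_one=False):
    """Trace cur_len across steps; always_one reproduces the buggy update."""
    trace = [cur_len]
    for q_len in q_lens:
        cur_len += 1 if always_one else q_len
        trace.append(cur_len)
    return trace

good = advance(0, [32, 1, 1])                    # prefill 32, then two decodes
bad = advance(0, [32, 1, 1], always_one=True)    # the "+= 1 always" bug
print(good, bad)
```

With the buggy rule the first decode would be issued at position 1 instead of 32, so its KV write would land on a slot the prompt already filled.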
5.5 Per-request sampling params
def build_params(self, batch):
    return SamplingParams(
        temperature = torch.tensor([r.sampling_param["temperature"] for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
        top_k       = torch.tensor([r.sampling_param["top_k"] for r,_ in batch],
                                   dtype=torch.int32, device='cuda'),
        top_p       = torch.tensor([r.sampling_param["top_p"] for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
        rep_penalty = torch.tensor([r.sampling_param["rep_penalty"] for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
    )
prev_ids is right-padded to a common length so the sampler's gather works:
def build_prev_ids(self, batch):
    max_prev = max(r.prompt_ids.numel() + len(r.output_ids) for r,_ in batch)
    out = torch.full((len(batch), max_prev), 0, dtype=torch.long, device='cuda')
    for i, (r, _) in enumerate(batch):
        hist = torch.cat([r.prompt_ids,
                          torch.tensor(r.output_ids, dtype=r.prompt_ids.dtype, device='cuda')])
        out[i, :hist.numel()] = hist
    return out
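The padding step itself can be illustrated without torch. This sketch uses plain lists and a hypothetical pad_histories helper; the validity mask is an addition for illustration, not something the code above returns. It is worth noting because the pad value 0 is also a real token id, so a sampler that applies a repetition penalty over the padded columns may end up penalising token 0 unless it masks them.

```python
def pad_histories(histories, pad_id=0):
    """Right-pad each token history to the longest one; also return a validity mask."""
    width = max(len(h) for h in histories)
    padded = [h + [pad_id] * (width - len(h)) for h in histories]
    mask = [[True] * len(h) + [False] * (width - len(h)) for h in histories]
    return padded, mask

histories = [[11, 12, 13, 14], [21], [31, 32]]   # prompt + output ids per request
padded, mask = pad_histories(histories)
print(padded)
print(mask)
```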
5.6 Q&A from session
student question
Why logits[-1:] in L4? Are we ignoring the last position?
Backwards — [-1:] keeps the last row (kept-dim slice). The rows we throw away are 0..T-2, which predict tokens the prompt already provided (teacher-forced). Sampling them would be wasted compute — we only care what the model says about the next token.
student question
In L5 with a batch dim, do we do logits[:, -1:]?
No — our engine uses packed/varlen layout, not padded batch. Logits is [T_total, vocab], no batch axis. We pick per-seq last rows with logits[(cu_seqlens_q[1:] - 1).long()]. Walked through with q_lens=[5,1,3]: cu_seqlens_q=[0,5,6,9] → pick rows [4,5,8] of the [9, vocab] tensor.
The shape, visualised
q_lens       = [5, 1, 3]   # 3 sequences
T_total      = 9
logits.shape = [9, vocab]

row index:   0 1 2 3 4 | 5 | 6 7 8
belongs to:  A A A A A | B | C C C
                     ↑   ↑       ↑
                  last   last last

cu_seqlens_q       = [0, 5, 6, 9]
cu_seqlens_q[1:]-1 = [4, 5, 8]          ← gather indices
logits[[4, 5, 8]]  → shape [3, vocab]   ← per-seq last-row logits
sampler(...)       → shape [3]          ← one token per request
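The same gather can be checked in plain Python. In this sketch the logits are fabricated so that row i peaks at token i % vocab, and greedy argmax stands in for the sampler; last_row_indices and argmax are hypothetical helpers.

```python
from itertools import accumulate

def last_row_indices(q_lens):
    """Row index of the final query token of each packed sequence."""
    return [end - 1 for end in accumulate(q_lens)]

def argmax(row):
    return max(range(len(row)), key=row.__getitem__)

q_lens, vocab = [5, 1, 3], 3
# Fake [T_total, vocab] logits: row i has its peak at token i % vocab.
logits = [[1.0 if t == i % vocab else 0.0 for t in range(vocab)]
          for i in range(sum(q_lens))]

rows = last_row_indices(q_lens)
next_ids = [argmax(logits[i]) for i in rows]       # one token per sequence
only_last = [argmax(row) for row in logits[-1:]]   # the logits[-1:] habit from L4
print(rows, next_ids, only_last)
```

The logits[-1:] slice yields a single token for a three-sequence batch, which is why the per-sequence gather is needed once sequences are packed.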
5.7 Pitfalls (table of nine)
| pitfall | symptom |
|---|---|
| cur_len += 1 always | state drifts on prefill steps |
| not bumping cur_len before next select | infinite re-prefill |
| using output_ids[-1] for prefills | wrong: prefills consume next chunk of prompt |
| block_table padded with a valid block id | kernel reads other request's KV (defensive bug) |
| selecting decodes whose cur_len already advanced | double-write to same slot |
| prev_ids length mismatch row-to-row | sampler gather crashes |
| eos_id confused with pad_id | premature stop |
| logits[-1:] with multi-seq batch | only the last sequence's token gets sampled |
| reserve uses cur_len not cur_len + q_len | fresh prefill reserves zero slots → IndexError |
5.8 The batch-invariance problem (verified)
Greedy outputs from a multi-request batch can differ from solo runs because cuBLAS picks different matmul algorithms based on the input shape. A [5, hidden] matmul and a [15, hidden] matmul have different reduction orders → ULP-scale fp drift → argmax flips when two top logits are close.
Demonstrated in the session: prompt 0 ("capital of France") run alone matches the L1-L4 reference exactly. Run in a batch of 3, it diverges at token 5 (Italy instead of France). Prompt 2 (longest, last) matches in both because its row is the last row of the whole tensor — exactly the layout the single-seq case sees.
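The mechanism can be reproduced on a CPU in a few lines. This is only an analogy in float64, not the cuBLAS behaviour itself: it shows that changing the order of a floating-point sum changes the result by one ULP, and that one ULP is enough to flip an argmax when the competing logit is that close. The rival value and the first-index tie-break are assumptions of the sketch.

```python
def argmax(row):
    # Python's max returns the first maximal element, so ties go to the lower index.
    return max(range(len(row)), key=row.__getitem__)

# Same three addends, two reduction orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

rival = 0.6   # a competing logit sitting exactly at the mathematically expected sum
pick_a = argmax([rival, left_to_right])   # the sum comes out one ULP above the rival
pick_b = argmax([rival, right_to_left])   # the sum ties the rival, so index 0 wins
print(left_to_right, right_to_left, pick_a, pick_b)
```

On a GPU the same effect appears at lower precision and across thousands of addends per dot product, so near-ties between top logits are where batched and solo runs part ways.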
student question
How do we achieve true deterministic parity?
Three layers:
- Env + flags (run-to-run reproducibility, ~5–15% perf cost):
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
- SDPA pin to MATH backend (attention batch invariance, 2–5× slower attention):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
- Pad projections to M_max or custom Triton kernels (projection batch invariance). Expensive, beyond curriculum.
Honest reality: no production engine (vLLM, SGLang, TRT-LLM, DeepSpeed) fully achieves bit-equal batched parity. It's a verification-time tool, not a production capability. SGLang has open tickets on batch invariance.
5.9 Acceptance
L5 pass criteria (the right ones)
Bit-equality across batched vs solo is the wrong test; see §5.8.
- A. Solo-through-scheduler matches L1–L4 reference bit-for-bit (verified ✓).
- B. Batched run produces coherent text per request and the right token count.
- C. At least one mixed prefill+decode step occurred.
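The mechanical parts of these criteria could be checked with a small helper along the following lines. This is a sketch under assumptions: check_acceptance and its argument shapes are hypothetical, "the right token count" is interpreted as between 1 and the request's max_tokens, and coherence of the text in criterion B still needs a human read.

```python
def check_acceptance(solo_ids, reference_ids, batched_outputs, max_tokens, step_kinds):
    """Return a pass/fail flag per criterion.

    batched_outputs: {req_id: [token ids]} from the batched run
    max_tokens:      {req_id: limit}
    step_kinds:      per step, the list of "prefill"/"decode" tags in that batch
    """
    a = solo_ids == reference_ids                                   # A: exact match, solo path
    b = all(1 <= len(out) <= max_tokens[rid]                        # B: token count only
            for rid, out in batched_outputs.items())
    c = any("prefill" in kinds and "decode" in kinds                # C: a mixed step happened
            for kinds in step_kinds)
    return {"A": a, "B": b, "C": c}

result = check_acceptance(
    solo_ids=[5, 6, 7],
    reference_ids=[5, 6, 7],
    batched_outputs={0: [5, 6, 7], 1: [9], 2: [3, 4]},
    max_tokens={0: 3, 1: 1, 2: 2},
    step_kinds=[["prefill"], ["decode", "prefill"], ["decode", "decode", "prefill"]],
)
print(result)
```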