mini_sglang

L5 — Scheduler / continuous batching

Multiple requests in one forward pass. Where most of the throughput comes from.

5.1 The problem

Three requests arrive on a busy server:

t=0    req A arrives, prompt = 8 tokens
t=10   req B arrives, prompt = 32 tokens   (A is mid-decode)
t=15   req C arrives, prompt = 5 tokens    (A and B both mid-decode)

A naïve "one request per forward call" server processes A fully, then B, then C. Tail latency is awful and the GPU sits idle during single-token decodes (Qwen3-8B's decode kernel utilises maybe 5–15% of an H100's compute).

Continuous batching: every forward step, ask "who's runnable?" and pack them all into a single call. The model doesn't change. The KV pool doesn't change. What changes is the bookkeeping that builds ForwardMeta — instead of one request's worth, it's whatever's runnable, mixed.

step 0:   prefill A (8 q tokens)         + nothing else
step 1:   decode A (1)                   + prefill B (32 q tokens)
step 2:   decode A (1) + decode B (1)    + prefill C (5 q tokens)
step 3:   decode A (1) + decode B (1)    + decode C (1)
…

The kernel call shape goes from "one sequence at a time" to "packed batch of mixed q_lens" — which is exactly why we built cu_seqlens_q / seq_lens_kv / block_table in L3. L5 is the consumer that justifies that design.
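The step table above can be reproduced with a few lines of plain Python. This is a toy sketch, not the engine: `plan_steps` is a hypothetical helper, and `arrivals` maps a step index (rather than a wall-clock time) to the prompts admitted on that step.

```python
def plan_steps(arrivals, num_steps):
    """Return, per step, a list of (name, kind, q_len) tuples.

    arrivals: dict step -> list of (name, prompt_len); hypothetical format.
    A request is prefilled on the step it is admitted and decodes one
    token on every later step (no EOS handling in this toy).
    """
    running = []                      # names of requests that already have KV
    plan = []
    for step in range(num_steps):
        batch = [(name, "decode", 1) for name in running]
        new = arrivals.get(step, [])
        batch.extend((name, "prefill", prompt_len) for name, prompt_len in new)
        running.extend(name for name, _ in new)
        plan.append(batch)
    return plan


arrivals = {0: [("A", 8)], 1: [("B", 32)], 2: [("C", 5)]}
plan = plan_steps(arrivals, num_steps=4)
q_lens_per_step = [[q for _, _, q in batch] for batch in plan]
tokens_per_step = [sum(q_lens) for q_lens in q_lens_per_step]
print(q_lens_per_step)   # one list of q_lens per forward call
print(tokens_per_step)   # T_total of each packed forward call
```

Each inner list is exactly the `q_lens` of one packed forward call, which is the quantity `cu_seqlens_q` encodes.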

5.2 The scheduler's public API

Three methods. This is roughly the surface most modern engines expose to the server layer.

class Scheduler:
    def add_request(self, req: Request) -> None: ...
    def step(self) -> StepResult: ...                 # one model forward
    def has_unfinished(self) -> bool: ...

@dataclass
class StepResult:
    new_tokens: dict[int, int]   # req_id → token sampled this step
    finished:   list[int]        # req_ids that completed this step

The server loop is then:

while scheduler.has_unfinished() or pending:
    while pending:
        scheduler.add_request(pending.popleft())
    out = scheduler.step()
    for rid, tok in out.new_tokens.items():
        stream_to_client(rid, tok)
    for rid in out.finished:
        close_stream(rid)
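To see the loop run end to end without a model, here is a minimal sketch with a stand-in scheduler. `ToyScheduler`, the simplified `Request` fields, and the "token = history length" rule are all assumptions made for illustration; only the three-method surface and `StepResult` follow the text above.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Request:                       # simplified, hypothetical request type
    id: int
    prompt_ids: list
    max_tokens: int
    output_ids: list = field(default_factory=list)


@dataclass
class StepResult:
    new_tokens: dict                 # req_id -> token produced this step
    finished: list                   # req_ids that completed this step


class ToyScheduler:
    """Same three-method surface; the 'model' just emits the history length."""

    def __init__(self):
        self.active = deque()

    def add_request(self, req):
        self.active.append(req)

    def has_unfinished(self):
        return bool(self.active)

    def step(self):
        out = StepResult(new_tokens={}, finished=[])
        still_active = deque()
        for req in self.active:
            tok = len(req.prompt_ids) + len(req.output_ids)   # stand-in for sampling
            req.output_ids.append(tok)
            out.new_tokens[req.id] = tok
            if len(req.output_ids) >= req.max_tokens:
                out.finished.append(req.id)
            else:
                still_active.append(req)
        self.active = still_active
        return out


pending = deque([Request(0, [1, 2, 3], max_tokens=2), Request(1, [4], max_tokens=3)])
scheduler = ToyScheduler()
streams, closed = {}, []             # stand-ins for stream_to_client / close_stream
while scheduler.has_unfinished() or pending:
    while pending:
        scheduler.add_request(pending.popleft())
    out = scheduler.step()
    for rid, tok in out.new_tokens.items():
        streams.setdefault(rid, []).append(tok)
    for rid in out.finished:
        closed.append(rid)
```

The server loop never needs to know how many requests were batched in a step; it only consumes the per-request results.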

5.3 The two queues

            ┌──────────────┐  select admits up to MAX_BATCH
new req ──► │ WAITING      │ ─────────────────────────────┐
            │ (no KV yet)  │                              │
            └──────────────┘                              ▼
                                                  ┌──────────────┐
            EOS / max_tokens ◄─── writeback ◄──── │ RUNNING      │
                                                  │ (has KV; will │
                                                  │  advance ≥ 1) │
                                                  └──────────────┘

WAITING requests need to be prefilled. RUNNING requests need to be decoded. Each step mixes some of both.

5.4 Four scheduling decisions

A. Admission

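def select(self):
    decode_batch  = list(self.running)                # every running request decodes 1 token
    prefill_batch = []
    budget = MAX_TOKENS_PER_STEP - len(decode_batch)  # q-token budget left for prefills
    free   = self.alloc.num_free_tokens() - len(decode_batch)  # each decode takes 1 new KV slot
    while self.waiting and budget > 0 and len(decode_batch)+len(prefill_batch) < MAX_BATCH:
        head = self.waiting[0]
        need = len(head.prompt_ids)
        if need > budget: break                       # FIFO: the head of the line blocks the queue
        if free < need + 1: break                     # prompt + first decoded token must fit
        prefill_batch.append((self.waiting.popleft(), need))
        budget -= need
        free   -= need + 1                            # count slots already promised in this call
    return decode_batch, prefill_batch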
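A worked example of the admission policy, as a standalone sketch. The constants are hypothetical values picked to keep the numbers small, prompts are represented only by their lengths, and `admit` is an illustrative name rather than part of the engine.

```python
from collections import deque

MAX_TOKENS_PER_STEP = 48   # hypothetical values for the example
MAX_BATCH = 4


def admit(waiting, num_running, free_tokens):
    """Pop prompt lengths off `waiting` while they fit; return those admitted.

    FIFO with head-of-line blocking: the first prompt that does not fit
    stops admission for this step.
    """
    admitted = []
    budget = MAX_TOKENS_PER_STEP - num_running   # each decode costs 1 q token
    free = free_tokens - num_running             # each decode needs 1 new KV slot
    while waiting and budget > 0 and num_running + len(admitted) < MAX_BATCH:
        need = waiting[0]
        if need > budget or free < need + 1:
            break
        admitted.append(waiting.popleft())
        budget -= need
        free -= need + 1                         # prompt + first decoded token
    return admitted


waiting = deque([32, 5, 40])
admitted = admit(waiting, num_running=1, free_tokens=1000)
# With one decode running, the budget is 47: 32 and 5 fit, 40 does not.

too_long = admit(deque([100]), num_running=0, free_tokens=1000)
# A prompt longer than MAX_TOKENS_PER_STEP is never admitted under this policy.
```

The second call shows a consequence of the policy as written: a prompt longer than the per-step budget blocks the queue indefinitely unless prefill is split into chunks.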

B. Reservation

For each request, make sure it holds cur_len + q_len slots in total (a decode adds 1, a prefill adds q_len). Pre-check all allocations before touching the model: a partial allocation leaves you in an inconsistent state.
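One way to get the all-or-nothing behaviour is to compute every shortfall first and commit only if the total fits. The sketch below assumes a counter-based `ToyAllocator` and a `(cur_len, q_len, slots_held)` tuple per request; both are hypothetical stand-ins, not the engine's allocator.

```python
class ToyAllocator:
    """Hypothetical stand-in for the KV allocator: just a free-slot counter."""

    def __init__(self, num_slots):
        self.free = num_slots

    def reserve(self, n):
        if n > self.free:
            raise MemoryError("out of KV slots")
        self.free -= n


def reserve_batch(alloc, batch):
    """batch: list of (cur_len, q_len, slots_held) tuples.

    Returns True and commits every reservation, or returns False and
    leaves the allocator untouched.
    """
    shortfalls = [max(0, cur_len + q_len - held) for cur_len, q_len, held in batch]
    if sum(shortfalls) > alloc.free:
        return False                  # pre-check failed: nothing was reserved
    for n in shortfalls:
        alloc.reserve(n)
    return True


alloc = ToyAllocator(num_slots=40)
ok_small = reserve_batch(alloc, [(8, 1, 8), (0, 32, 0)])   # one decode + one fresh prefill
free_after_small = alloc.free
ok_big = reserve_batch(alloc, [(9, 1, 9), (0, 10, 0)])     # needs 11, only 7 left
free_after_big = alloc.free
```

Because the failing batch changes nothing, the scheduler can simply drop the last admitted prefill and retry.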

C. Packing — the ForwardMeta builder

def build_meta(self, batch):
    positions, slot_mapping, cu_seqlens_q, seq_lens_kv = [], [], [0], []
    block_table_rows = []
    max_blocks = max(len(r.blocks) for r, _ in batch)

    for r, q_len in batch:
        for p in range(r.cur_len, r.cur_len + q_len):
            positions.append(p)
            slot_mapping.append(r.slot_indices[p])
        cu_seqlens_q.append(cu_seqlens_q[-1] + q_len)
        seq_lens_kv.append(r.cur_len + q_len)
        block_table_rows.append(r.blocks + [-1] * (max_blocks - len(r.blocks)))

    return ForwardMeta(
        positions    = torch.tensor(positions,    device='cuda', dtype=torch.long),
        slot_mapping = torch.tensor(slot_mapping, device='cuda', dtype=torch.long),
        cu_seqlens_q = torch.tensor(cu_seqlens_q, device='cuda', dtype=torch.int32),
        seq_lens_kv  = torch.tensor(seq_lens_kv,  device='cuda', dtype=torch.int32),
        block_table  = torch.tensor(block_table_rows, device='cuda', dtype=torch.int32),
        block_size   = self.alloc.block_size,
    )

is_prefill dies here. In a mixed batch, a single boolean has no global meaning. attn_paged_torch already handles causality per sequence from that sequence's own q_len and kv_len; a fresh prefill is just the case q_len == kv_len.
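A torch-free sketch of the same packing arithmetic, applied to step 2 of the timeline in 5.1. `pack` and `visible_kv` are illustrative names, and the `visible_kv` line states the standard causal rule rather than quoting attn_paged_torch.

```python
def pack(batch):
    """batch: list of (cur_len, q_len) pairs, one per sequence."""
    positions, cu_seqlens_q, seq_lens_kv = [], [0], []
    for cur_len, q_len in batch:
        positions.extend(range(cur_len, cur_len + q_len))
        cu_seqlens_q.append(cu_seqlens_q[-1] + q_len)
        seq_lens_kv.append(cur_len + q_len)
    return positions, cu_seqlens_q, seq_lens_kv


# Step 2: A decodes (cur_len 9), B decodes (cur_len 32),
# C prefills its 5-token prompt (cur_len 0).
batch = [(9, 1), (32, 1), (0, 5)]
positions, cu_seqlens_q, seq_lens_kv = pack(batch)

# Standard causal rule: a query at position p may attend to KV entries 0..p,
# that is, p + 1 of them. No prefill/decode flag is needed to compute this.
visible_kv = [p + 1 for p in positions]
```

Decode rows and prefill rows fall out of the same formula, which is why a per-batch is_prefill flag has nothing left to decide.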

D. Writeback

The model returns [T_total, vocab] logits. Pick the last row of each sequence:

last_rows = (meta.cu_seqlens_q[1:] - 1).long()      # [num_seqs]
per_seq   = logits[last_rows]                       # [num_seqs, vocab]
next_ids  = sampler(per_seq.float(), params, prev_ids)   # [num_seqs]

for (r, q_len), tok in zip(batch, next_ids.tolist()):
    r.output_ids.append(tok)
    r.cur_len += q_len                              # NOT always +1
    if tok == eos_id or len(r.output_ids) >= r.max_tokens:
        finished.append(r.id)
    else:
        self.running.append(r)                      # requeue!

Off-by-one trap: cur_len += q_len, NOT += 1. A prefill of 32 tokens advances cur_len from 0 to 32 in one step. A decode advances by 1.
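The writeback arithmetic can be checked on plain lists. This sketch assumes greedy argmax in place of the real sampler, and `writeback` is a hypothetical helper name.

```python
def writeback(cur_lens, q_lens, logits_rows):
    """cur_lens, q_lens: per-sequence ints; logits_rows: T_total rows of floats.

    Returns (next_ids, new_cur_lens), taking a greedy argmax over the
    last packed row of each sequence.
    """
    next_ids, new_cur_lens, end = [], [], 0
    for cur_len, q_len in zip(cur_lens, q_lens):
        end += q_len                          # running cu_seqlens_q value
        row = logits_rows[end - 1]            # last row belonging to this sequence
        next_ids.append(max(range(len(row)), key=row.__getitem__))
        new_cur_lens.append(cur_len + q_len)  # += q_len, not += 1
    return next_ids, new_cur_lens


# One decode (cur_len 8, q_len 1) packed with one 3-token prefill (cur_len 0).
logits_rows = [
    [0.1, 0.9, 0.0],   # row 0: the decode's only row
    [0.5, 0.2, 0.3],   # rows 1..3: the prefill's rows; only row 3 is sampled
    [0.0, 0.1, 0.2],
    [0.7, 0.2, 0.1],
]
next_ids, new_cur_lens = writeback([8, 0], [1, 3], logits_rows)
```

With `+= 1` the prefill would end at cur_len 1 instead of 3, and the next step would overwrite KV slots the prompt already filled.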

5.5 Per-request sampling params

def build_params(self, batch):
    return SamplingParams(
        temperature = torch.tensor([r.sampling_param["temperature"] for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
        top_k       = torch.tensor([r.sampling_param["top_k"]       for r,_ in batch],
                                   dtype=torch.int32,   device='cuda'),
        top_p       = torch.tensor([r.sampling_param["top_p"]       for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
        rep_penalty = torch.tensor([r.sampling_param["rep_penalty"] for r,_ in batch],
                                   dtype=torch.float32, device='cuda'),
    )

prev_ids is right-padded to a common length so the sampler's gather works:

def build_prev_ids(self, batch):
    max_prev = max(r.prompt_ids.numel() + len(r.output_ids) for r,_ in batch)
    out = torch.full((len(batch), max_prev), 0, dtype=torch.long, device='cuda')
    for i, (r, _) in enumerate(batch):
        hist = torch.cat([r.prompt_ids,
                          torch.tensor(r.output_ids, dtype=r.prompt_ids.dtype, device='cuda')])
        out[i, :hist.numel()] = hist
    return out
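The padding step on its own, as a list-based sketch. Returning `lengths` alongside the padded rows is an assumption added here: the text does not say whether the sampler masks padded positions, and since the pad value 0 is also a valid token id, a consumer that treats every entry as "seen" would count token 0 as well. `pad_history` is a hypothetical name.

```python
def pad_history(histories, pad_id=0):
    """Right-pad token histories to a common length.

    Returns (padded, lengths); lengths[i] marks where row i's real
    tokens end, so a consumer can ignore the padding.
    """
    max_len = max(len(h) for h in histories)
    padded = [h + [pad_id] * (max_len - len(h)) for h in histories]
    lengths = [len(h) for h in histories]
    return padded, lengths


padded, lengths = pad_history([[11, 12, 13, 14], [21], [31, 32]])
```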

5.6 Q&A from session

Student question: Why logits[-1:] in L4? Are we ignoring the last position?

Backwards — [-1:] keeps the last row (kept-dim slice). The rows we throw away are 0..T-2, which predict tokens the prompt already provided (teacher-forced). Sampling them would be wasted compute — we only care what the model says about the next token.

Student question: In L5 with a batch dim, do we do logits[:, -1:]?

No — our engine uses packed/varlen layout, not padded batch. Logits is [T_total, vocab], no batch axis. We pick per-seq last rows with logits[(cu_seqlens_q[1:] - 1).long()]. Walked through with q_lens=[5,1,3]: cu_seqlens_q=[0,5,6,9] → pick rows [4,5,8] of the [9, vocab] tensor.

The shape, visualised

q_lens          = [5, 1, 3]                    # 3 sequences
T_total         = 9
logits.shape    = [9, vocab]

row index:   0  1  2  3  4  | 5 |  6  7  8
belongs to:  A  A  A  A  A  | B |  C  C  C
                          ↑    ↑          ↑
                       last  last      last

cu_seqlens_q       = [0, 5, 6, 9]
cu_seqlens_q[1:]-1 = [4, 5, 8]                  ← gather indices

logits[[4, 5, 8]]  →  shape [3, vocab]          ← per-seq last-row logits
sampler(...)       →  shape [3]                 ← one token per request

5.7 Pitfalls (table of nine)

pitfall                                          | symptom
-------------------------------------------------+------------------------------------------------
cur_len += 1 always                              | state drifts on prefill steps
not bumping cur_len before next select           | infinite re-prefill
using output_ids[-1] for prefills                | wrong: prefills consume next chunk of prompt
block_table padded with a valid block id         | kernel reads other request's KV (defensive bug)
selecting decodes whose cur_len already advanced | double-write to same slot
prev_ids length mismatch row-to-row              | sampler gather crashes
eos_id confused with pad_id                      | premature stop
logits[-1:] with multi-seq batch                 | only the last sequence's token gets sampled
reserve uses cur_len not cur_len + q_len         | fresh prefill reserves zero slots → IndexError

5.8 The batch-invariance problem (verified)

Greedy outputs from a multi-request batch can differ from solo runs because cuBLAS picks different matmul algorithms based on the input shape. A [5, hidden] matmul and a [15, hidden] matmul have different reduction orders → ULP-scale fp drift → argmax flips when two top logits are close.
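The mechanism can be shown without a GPU. The sketch below uses Python floats (fp64) rather than cuBLAS, so it illustrates only the principle: floating-point addition is not associative, and a one-ULP difference is enough to change a greedy pick when two logits are this close. The `argmax` helper is local to the example and returns the first index on ties.

```python
def argmax(xs):
    """Index of the largest element; the first index wins on ties."""
    return max(range(len(xs)), key=xs.__getitem__)


# The same three addends, reduced in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)
orders_differ = left_to_right != right_to_left

# Token 0 has logit 0.6. Token 1's logit is the sum above, so its value
# depends only on the order in which its partial sums were accumulated.
pick_a = argmax([0.6, left_to_right])
pick_b = argmax([0.6, right_to_left])
```

Here `pick_a` and `pick_b` disagree even though both runs computed "the same" sum, which is the solo-versus-batched situation in miniature.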

Demonstrated in the session: prompt 0 ("capital of France") run alone matches the L1-L4 reference exactly. Run in a batch of 3, it diverges at token 5 (Italy instead of France). Prompt 2 (longest, last) matches in both because its row is the last row of the whole tensor — exactly the layout the single-seq case sees.

Student question: How do we achieve true deterministic parity?

Three layers:

  1. Env + flags (run-to-run reproducibility, ~5–15% perf cost; set the env var before any CUDA work):
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
  2. SDPA pin to MATH backend (attention batch invariance, 2–5× slower attention):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
  3. Pad projections to M_max or custom Triton kernels (projection batch invariance). Expensive, beyond curriculum.

Honest reality: no production engine (vLLM, SGLang, TRT-LLM, DeepSpeed) fully achieves bit-equal batched parity. It's a verification-time tool, not a production capability. SGLang has open tickets on batch invariance.

5.9 Acceptance

L5 pass criteria (the right ones): bit-equality across batched vs solo is the wrong test; see §5.8.