mini_sglang

L4 — Sampler

Factor token selection into a real module: greedy, temperature, top-k, top-p, penalties.

4.1 Where the sampler lives

logits_full = model(input_ids, pool, meta)         # [T_total, vocab]
logits      = logits_full[last_token_per_seq]      # [num_seqs, vocab]  ← scheduler picks rows
next_ids    = sampler(logits, sampling_params)     # [num_seqs]         ← THIS lesson
req.output_ids.append(next_ids[s])                 # writeback per seq

Two structural points:

  1. The model returns logits for every token in the packed input. The sampler's input is whatever rows the scheduler picks, typically the last token of each sequence: logits[cu_seqlens_q[1:] - 1]. The -1 matters because cu_seqlens_q holds exclusive end offsets: sequence i occupies rows cu_seqlens_q[i] through cu_seqlens_q[i+1] - 1.
  2. The sampler is per-sequence, not per-batch. Different requests carry different temperature, top_p, etc. Parameters must be 1-D tensors of length num_seqs, not scalars.
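A minimal sketch of the row selection, assuming cu_seqlens_q is the zero-prefixed cumulative sum of per-sequence query lengths. The lengths and random logits are made up for illustration:

```python
import torch

# Packed batch of 3 sequences with query lengths 3, 1, 2 (e.g. two prefills, one decode).
seq_lens_q = torch.tensor([3, 1, 2])
# cu_seqlens_q = [0, 3, 4, 6]; sequence i owns rows cu[i] .. cu[i+1]-1 (end is exclusive)
cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens_q.cumsum(0)])

T_total, vocab = int(cu_seqlens_q[-1]), 8
logits_full = torch.randn(T_total, vocab)      # stand-in for the model output [T_total, vocab]

last_token_rows = cu_seqlens_q[1:] - 1         # [2, 3, 5]
logits = logits_full[last_token_rows]          # [num_seqs, vocab]
print(last_token_rows.tolist(), logits.shape)
```

Without the -1 the indices would be [3, 4, 6]: the first token of the next sequence, and out of bounds for the last one.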

This is the only stateful, per-request thing in the forward path — the rest of the model is a pure tensor function. That's why sampling lives outside the model graph.

4.2 The canonical pipeline order

logits ─► penalties ─► / temperature ─► top-k mask ─► top-p mask ─► softmax ─► multinomial
                                                                       │
                                                          greedy rows: argmax override

Order is contested between frameworks, and mistakes here fail silently; the common bugs are collected in the pitfalls recap (4.8).

4.3 Building blocks

A. Greedy

next_id = logits.argmax(dim=-1)   # [num_seqs]

Equivalent to temperature == 0. The L3 driver does this inline; the sampler wraps it.

B. Temperature (with the divide-by-zero guard)

mask_greedy = (temperature == 0)
t_safe      = temperature.masked_fill(mask_greedy, 1.0)
logits      = logits / t_safe.unsqueeze(-1)         # [num_seqs, vocab] / [num_seqs, 1]
# ...at the end, override greedy rows with argmax
next_ids = torch.where(mask_greedy, logits.argmax(-1), next_ids)

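Two reasons not to skip the guard: dividing by zero gives inf/nan, and swapping 0 for a tiny epsilon is no fix either. At 1e-3 the logits are scaled by 1000×, runner-up probabilities underflow to exactly 0, and in half precision the scaled logits can overflow to inf, which softmax turns into nan.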
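The guard as a self-contained function. This is a sketch assuming temperature arrives as a float tensor of shape [num_seqs]; the name apply_temperature is hypothetical:

```python
import torch

def apply_temperature(logits: torch.Tensor, temperature: torch.Tensor):
    """Divide each row by its own temperature; rows with temperature == 0 stay unscaled.

    Returns the scaled logits plus the greedy mask needed later for the argmax override.
    """
    mask_greedy = temperature == 0                       # [num_seqs] bool
    t_safe = temperature.masked_fill(mask_greedy, 1.0)   # never divide by 0
    return logits / t_safe.unsqueeze(-1), mask_greedy    # [N, V] / [N, 1] broadcasts per row

logits = torch.tensor([[2.0, 1.0, 0.0],
                       [2.0, 1.0, 0.0]])
temperature = torch.tensor([0.0, 0.5])                   # row 0 greedy, row 1 sharpened
scaled, mask_greedy = apply_temperature(logits, temperature)
print(scaled)       # row 0 unchanged, row 1 doubled
print(mask_greedy)  # row 0 greedy, row 1 sampled
```

Returning the mask from the same place that builds t_safe keeps the guard and the later override from drifting apart.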

C. Top-k

Keep only the top k logits per row; everything else becomes -inf.

if (k > 0).any():
    k_max = min(int(k.max()), logits.size(-1))          # topk raises if k_max > vocab
    topk_vals, topk_idx = logits.topk(k_max, dim=-1)    # [N, k_max], sorted descending
    # keep only the first k[i] entries of row i
    keep = torch.arange(k_max, device=logits.device)[None, :] < k.unsqueeze(-1)
    new  = torch.full_like(logits, float('-inf'))
    new.scatter_(-1, topk_idx, topk_vals.masked_fill(~keep, float('-inf')))
    logits = torch.where((k > 0).unsqueeze(-1), new, logits)   # k == 0 rows bypass

Per-row k via topk(k_max) plus a "keep only the first k[i] of each row" mask. Rows with k=0 bypass top-k entirely via the final torch.where.
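An alternative formulation that avoids the scatter: take each row's k[i]-th largest value as a threshold and mask everything below it. This is a sketch, not necessarily what mini_sglang does, and the helper name is hypothetical. One behavioural difference: ties at the threshold are all kept, so a row can keep slightly more than k[i] tokens.

```python
import torch

def top_k_threshold(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-row top-k via a threshold. k[i] == 0 leaves row i untouched."""
    if not (k > 0).any():
        return logits
    k_max = min(int(k.max()), logits.size(-1))            # topk raises if k > vocab
    topk_vals, _ = logits.topk(k_max, dim=-1)             # [N, k_max], descending
    # position of the k[i]-th largest value; clamp so k == 0 rows still index safely
    kth = (k.long().clamp(1, k_max) - 1).unsqueeze(-1)    # [N, 1]
    thresh = topk_vals.gather(-1, kth)                    # [N, 1]
    masked = logits.masked_fill(logits < thresh, float('-inf'))
    return torch.where((k > 0).unsqueeze(-1), masked, logits)

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0],
                       [1.0, 4.0, 2.0, 3.0]])
k = torch.tensor([2, 0])
out = top_k_threshold(logits, k)   # row 0 keeps 4.0 and 3.0; row 1 (k=0) is unchanged
```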

D. Top-p (nucleus)

Keep the smallest set of tokens whose cumulative probability ≥ p.

if (p < 1.0).any():
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    sorted_probs = sorted_logits.float().softmax(dim=-1)   # fp32 so the CDF doesn't drift in bf16
    cum          = sorted_probs.cumsum(dim=-1)

    remove = cum > p.unsqueeze(-1)
    # SHIFT RIGHT: the boundary token (the one that pushes past p) must be kept
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0]  = False

    # project sorted mask back to vocab order
    remove_unsorted = torch.empty_like(remove)
    remove_unsorted.scatter_(-1, sorted_idx, remove)
    logits = logits.masked_fill(remove_unsorted, float('-inf'))
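
Pitfall: the shift-right is not optional. Without it the boundary token (the one that carries the cumulative mass past p) is dropped, so the kept set sums to less than p. Worse, when the top token alone already exceeds p, every token is removed, the row becomes all -inf, and softmax returns nan: there is nothing left to sample from.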
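A small sketch contrasting the mask with and without the shift, on a row whose top token already exceeds p. The helper name and the shift flag are hypothetical, added only to reproduce the bug:

```python
import torch

def nucleus_keep_mask(logits: torch.Tensor, p: torch.Tensor, shift: bool = True) -> torch.Tensor:
    """Boolean keep-mask (vocab order) for top-p. shift=False reproduces the buggy variant."""
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    cum = sorted_logits.float().softmax(dim=-1).cumsum(dim=-1)
    remove = cum > p.unsqueeze(-1)
    if shift:
        remove[..., 1:] = remove[..., :-1].clone()   # keep the boundary token
        remove[..., 0] = False                       # always keep the top token
    remove_unsorted = torch.empty_like(remove)
    remove_unsorted.scatter_(-1, sorted_idx, remove)  # back to vocab order
    return ~remove_unsorted

logits = torch.tensor([[3.0, 0.0, 0.0, 0.0]])       # top token prob is about 0.87
p = torch.tensor([0.5])
print(nucleus_keep_mask(logits, p, shift=False))    # all False: nothing left to sample
print(nucleus_keep_mask(logits, p, shift=True))     # only the top token kept
```

On a uniform 4-token row with p=0.6 the shifted mask keeps 3 tokens (cumulative 0.75 ≥ 0.6), which is the smallest set reaching p; the unshifted one would keep only 2 (cumulative 0.5).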

E. Penalties

# repetition penalty (multiplicative, HF style)
if (rep_p != 1.0).any():
    score = logits.gather(-1, prev_ids)                                 # [N, max_prev]
    score = torch.where(score > 0,
                        score / rep_p.unsqueeze(-1),
                        score * rep_p.unsqueeze(-1))
    logits.scatter_(-1, prev_ids, score)                                # IN-PLACE

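prev_ids shape contract: [num_seqs, max_prev]. Ragged histories must be right-padded to a common length, and the pad value matters. Repeating a token that already appears in the row (e.g. its first token) is safe: every duplicate gathers the same original score and scatters the same penalised value back, so nothing is penalised twice. Padding with an arbitrary id such as 0 penalises a token that never appeared. There is no inert id to pad with, so either mask the padding or use frequency-count tensors of shape [num_seqs, vocab] instead (cleaner mathematically).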
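A sketch of the count-tensor variant mentioned above, assuming the driver also passes prev_len (number of real entries per row). Padding is routed to a dummy column V that is dropped afterwards, so the pad id no longer matters. The helper and its signature are hypothetical:

```python
import torch

def rep_penalty_via_counts(logits, prev_ids, prev_len, rep_p):
    """Hypothetical helper: HF-style repetition penalty driven by a [N, V] count tensor.

    logits [N, V] float; prev_ids [N, max_prev] int64 (any pad id);
    prev_len [N] real entries per row; rep_p [N] float (1.0 = disabled).
    """
    N, V = logits.shape
    pos = torch.arange(prev_ids.size(-1), device=logits.device)
    valid = pos[None, :] < prev_len.unsqueeze(-1)              # [N, max_prev]
    idx = prev_ids.masked_fill(~valid, V)                      # padding -> dummy column V
    counts = torch.zeros(N, V + 1, dtype=torch.long, device=logits.device)
    counts.scatter_add_(-1, idx, torch.ones_like(idx))         # duplicates accumulate
    seen = counts[:, :V] > 0                                   # drop the dummy column
    r = rep_p.unsqueeze(-1)
    penalised = torch.where(logits > 0, logits / r, logits * r)
    return torch.where(seen, penalised, logits)                # out-of-place: no scatter_ trap

logits = torch.tensor([[2.0, -2.0, 1.0, 0.0]])
prev_ids = torch.tensor([[0, 1, 2]])          # third entry is padding (prev_len = 2)
out = rep_penalty_via_counts(logits, prev_ids, torch.tensor([2]), torch.tensor([2.0]))
# token 0: 2 / 2 = 1, token 1: -2 * 2 = -4, token 2 (padding only) untouched
```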

F. Final draw

probs    = logits.float().softmax(dim=-1)         # cast inside sampler, don't trust caller
next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_ids = torch.where(mask_greedy, logits.argmax(-1), next_ids)

The greedy override is what keeps the smoke test token-for-token identical to L3 at temperature=0.
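Putting blocks A to F together in the 4.2 order, a sketch of the whole pipeline as one function. It assumes per-row tensors of length N and prev_ids padded by repeating an id already in the row; the name sample and its signature are hypothetical, not mini_sglang's API:

```python
import torch

def sample(logits, temperature, top_k, top_p, rep_p, prev_ids, generator=None):
    """Sketch of the full sampler (hypothetical signature).

    logits [N, V]; temperature, top_p, rep_p: float [N]; top_k: int64 [N];
    prev_ids: int64 [N, max_prev], right-padded by repeating an id already in the row.
    """
    # fp32 copy: precision is owned here, and scatter_ can't mutate the caller's tensor
    logits = logits.to(torch.float32, copy=True)

    # E. repetition penalty (HF style); duplicate ids scatter the same value
    if (rep_p != 1.0).any():
        r = rep_p.unsqueeze(-1)
        score = logits.gather(-1, prev_ids)
        score = torch.where(score > 0, score / r, score * r)
        logits.scatter_(-1, prev_ids, score)

    # B. temperature with the divide-by-zero guard
    mask_greedy = temperature == 0
    logits = logits / temperature.masked_fill(mask_greedy, 1.0).unsqueeze(-1)

    # C. top-k; rows with k == 0 bypass
    if (top_k > 0).any():
        k_max = min(int(top_k.max()), logits.size(-1))
        vals, idx = logits.topk(k_max, dim=-1)
        keep = torch.arange(k_max, device=logits.device)[None, :] < top_k.unsqueeze(-1)
        new = torch.full_like(logits, float('-inf'))
        new.scatter_(-1, idx, vals.masked_fill(~keep, float('-inf')))
        logits = torch.where((top_k > 0).unsqueeze(-1), new, logits)

    # D. top-p with the shift-right; rows with p == 1.0 bypass
    if (top_p < 1.0).any():
        s_logits, s_idx = logits.sort(dim=-1, descending=True)
        cum = s_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum > top_p.unsqueeze(-1)
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        remove &= (top_p < 1.0).unsqueeze(-1)      # don't trim rounding tails on p == 1.0 rows
        remove_unsorted = torch.empty_like(remove).scatter_(-1, s_idx, remove)
        logits = logits.masked_fill(remove_unsorted, float('-inf'))

    # F. draw, then override greedy rows with argmax
    probs = logits.softmax(dim=-1)
    next_ids = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    return torch.where(mask_greedy, logits.argmax(dim=-1), next_ids)

logits = torch.tensor([[0.1, 3.0, 0.5, 2.0],
                       [1.0, 0.0, 2.5, 0.3]])
prev_ids = torch.tensor([[1], [2]])
ids = sample(logits, torch.zeros(2), torch.zeros(2, dtype=torch.long),
             torch.ones(2), torch.ones(2), prev_ids)   # all knobs at disabled defaults
```

With every knob disabled the result is plain argmax, which is exactly the L3 parity check.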

4.4 Disabled-default sentinels (memorize)

| knob | type | disabled value | why |
|---|---|---|---|
| temperature | float | 0.0 | triggers greedy override |
| top_k | int | 0 | convention: 0 = off |
| top_p | float | 1.0 | identity: keeps the whole vocab |
| rep_penalty | float | 1.0 | multiplicative identity |

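If your "disabled" smoke test passes top_p=0.0, top-p keeps only the single most likely token (the shift-right saves it; without the shift the whole vocab is gone), so every row silently degenerates to greedy and the greedy override hides it. rep_penalty=0.0 is worse if it ever runs: positive logits of previously seen tokens are divided by zero and become +inf (softmax then yields nan), while negative ones are multiplied up to zero. Seen tokens get boosted instead of penalised, so output collapses into repetition.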
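A small sketch replaying both wrong sentinels on toy tensors, assuming the shift-right top-p mask from D and the HF-style penalty from E:

```python
import torch

logits = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

# top_p = 0.0 through the shift-right mask: cum > 0 everywhere, the shift keeps index 0 only
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
cum = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
remove = cum > 0.0
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
kept = int((~torch.empty_like(remove).scatter_(-1, sorted_idx, remove)).sum())
print(kept)   # 1: sampling has silently become greedy

# rep_penalty = 0.0 applied to two previously seen tokens (one positive, one negative logit)
score = torch.tensor([2.0, -1.0])
rep = torch.tensor(0.0)
penalised = torch.where(score > 0, score / rep, score * rep)
print(penalised)              # +inf for the positive logit, zero for the negative one
print(penalised.softmax(-1))  # contains nan: inf - inf inside softmax
```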

4.5 Q&A from session

Student question: Why does the sampler need per-row tensors and not just scalars?

Because once the scheduler (L5) bunches multiple in-flight requests into one batch, request A might want temperature=0 (you, debugging) while request B wants temperature=0.7 (production user). The whole point of batching is that we run them in one kernel call. So sampling parameters become per-row from day one, even if the L4 smoke uses a single row.

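Student question: Why softmax twice (once for top-p, once for the final draw)?

The top-p softmax is a means to an end: it converts logits into a CDF so we can find the nucleus. After we mask logits to -inf, the final softmax over the masked logits gives the correctly re-normalised distribution. Two passes are not strictly required, though: a softmax restricted to the surviving tokens equals the full softmax with the dropped probabilities zeroed and divided by their remaining sum, so one softmax plus a renormalisation gives the same distribution. Two softmaxes is simply the most direct way to re-normalise over whatever survived every mask (top-k and top-p).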
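A quick numerical check of that identity on random logits with an arbitrary mask (the mask just has to leave at least one token per row):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 10)
keep = logits > 0.0
keep[:, 0] = True                          # guarantee at least one survivor per row

# two-softmax route: mask logits to -inf, then softmax again
two_pass = logits.masked_fill(~keep, float('-inf')).softmax(dim=-1)

# one-softmax route: zero the dropped probabilities and renormalise
probs = logits.softmax(dim=-1)
kept = probs.masked_fill(~keep, 0.0)
one_pass = kept / kept.sum(dim=-1, keepdim=True)

print(torch.allclose(two_pass, one_pass, atol=1e-6))   # True
```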

4.6 Implementation review (bugs we hit)

From reviewing the student implementation:

  1. Disabled-default sentinels were wrong. Driver passed top_p=0.0 and rep_penalty=0.0; correct sentinels are 1.0 for both. The smoke "passed" only because the greedy override masks the damage.
  2. Sampler ignored its own argument. rep_penalty was hardcoded to 1.2 inside the sampler instead of reading sample_cfg.rep_penalty.
  3. Non-inplace scatter. logits.scatter(-1, prev_ids, score) returns a new tensor; the original logits was unchanged. Use scatter_.
  4. prev_ids shape mismatch. Sampler did prev_ids.unsqueeze(-1) to make it 2-D, but the result had num_prev on dim 0 instead of num_seqs. gather raises as soon as num_prev > 1. The contract is [num_seqs, max_prev], period; the driver should build it as torch.stack(output_ids).unsqueeze(0) for single-seq.
  5. Top-p always runs. No if (top_p < 1.0).any() short-circuit; pays a full sort even when disabled.
  6. Softmax precision left to caller. Move .float() inside the sampler so it's not load-bearing on the driver.
  7. Garbage IDE auto-import at top of smoke (SAMPLE_INPUTS_FILENAME_FORMAT). Delete.
  8. Wrong pass banner. Smoke prints "L3 PASS" instead of "L4 PASS".
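A sketch reproducing items 3 and 4 on a single toy sequence; the vocab of 4, the history and the 1.2 penalty are made-up values:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])                       # one sequence, vocab 4
output_ids = [torch.tensor(1), torch.tensor(1), torch.tensor(3)]   # 0-d ids, appended per step

# Item 4: unsqueeze(-1) puts num_prev on dim 0 -> shape [3, 1]; gather on [1, 4] logits raises
bad = torch.stack(output_ids).unsqueeze(-1)
# Fix: build [num_seqs, max_prev] = [1, 3]
prev_ids = torch.stack(output_ids).unsqueeze(0)

score = logits.gather(-1, prev_ids)
score = torch.where(score > 0, score / 1.2, score * 1.2)

# Item 3: out-of-place scatter returns a new tensor that is thrown away
logits.scatter(-1, prev_ids, score)
was_untouched = torch.equal(logits, torch.tensor([[2.0, 1.0, 0.5, 0.0]]))
# Fix: scatter_ writes into logits (the duplicate id 1 writes the same value twice)
logits.scatter_(-1, prev_ids, score)
```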

4.7 Implementation order

  1. Define SamplingParams + per-row tensor packing helper.
  2. Greedy-only Sampler.__call__. Wire into smoke. Confirm same 20 IDs as L3.
  3. Add temperature with the == 0 guard. Re-test at temp=0.
  4. Add top-k (and its k=0 bypass). Re-test.
  5. Add top-p (with shift-right and p<1.0 guard). Re-test.
  6. Add repetition penalty (with scatter_ and rp!=1.0 guard). Re-test.
  7. Sanity check: at temp=0.7, top_p=0.9, top_k=50 over 64 tokens, output should be coherent English. Visual only, no bit-equality.

Every step keeps the smoke test as a regression: with every knob at its disabled default, the sampler must still produce the L3 output. If any step breaks that parity, an op escaped its "if disabled: skip" guard.
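For step 1, a sketch of SamplingParams plus the packing helper. The field names follow the knobs in 4.4, but the class shape, the validation rules and pack_sampling_params are assumptions; the validation simply rejects the wrong sentinels from 4.6 item 1 at construction time:

```python
from dataclasses import dataclass

import torch

@dataclass
class SamplingParams:
    """Per-request knobs with the disabled-default sentinels (hypothetical class)."""
    temperature: float = 0.0   # 0.0 -> greedy override
    top_k: int = 0             # 0 -> off
    top_p: float = 1.0         # 1.0 -> whole vocab
    rep_penalty: float = 1.0   # 1.0 -> multiplicative identity

    def __post_init__(self):
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}")
        if self.rep_penalty <= 0.0:
            raise ValueError(f"rep_penalty must be > 0, got {self.rep_penalty}")
        if self.temperature < 0.0 or self.top_k < 0:
            raise ValueError("temperature and top_k must be non-negative")

def pack_sampling_params(params: list, device="cpu") -> dict:
    """One SamplingParams per in-flight request -> 1-D tensors of length num_seqs."""
    def col(name, dtype):
        return torch.tensor([getattr(p, name) for p in params], dtype=dtype, device=device)
    return {
        "temperature": col("temperature", torch.float32),
        "top_k": col("top_k", torch.long),
        "top_p": col("top_p", torch.float32),
        "rep_penalty": col("rep_penalty", torch.float32),
    }

packed = pack_sampling_params([
    SamplingParams(),                                    # greedy, everything off
    SamplingParams(temperature=0.7, top_k=50, top_p=0.9),
])
```

Keeping the dataclass scalar and packing once per step means the scheduler in L5 only has to collect one SamplingParams per request.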

4.8 Pitfalls recap

| pitfall | symptom |
|---|---|
| softmax before top-k/p | top-p threshold uses the wrong distribution; silent quality drop |
| temperature == 0 not guarded | nan or all -inf after the first sample |
| top-p without shift-right | top token can be removed → all -inf row, softmax returns nan |
| bf16 multinomial | low-prob tokens collapse; distribution spikier than intended |
| per-row params passed as scalars | first request's setting silently shared with all |
| non-inplace scatter on logits | penalty silently does nothing |
| top_p=0 or rep_penalty=0 as "disabled" | wrong sentinels; greedy override hides the bug |

4.9 Acceptance

L4 pass criteria

When this passes, L5 (scheduler) becomes "build cu_seqlens_q, seq_lens_kv, block_table, slot_mapping, and the per-row sampling params from a list of in-flight Request objects every step". Mechanical.
