L4 — Sampler
4.1 Where the sampler lives
logits_full = model(input_ids, pool, meta) # [T_total, vocab]
logits = logits_full[last_token_per_seq] # [num_seqs, vocab] ← scheduler picks rows
next_ids = sampler(logits, sampling_params) # [num_seqs] ← THIS lesson
req.output_ids.append(next_ids[s]) # writeback per seq
Two structural points:
- The model returns logits for every token in the packed input. The sampler input is whatever rows the scheduler picks, typically the last token of each sequence: logits[cu_seqlens_q[1:] - 1]. The -1 matters because cu_seqlens_q is exclusive-end.
- The sampler is per-sequence, not per-batch. Different requests carry different temperature, top_p, etc. Parameters must be 1-D tensors of length num_seqs, not scalars.
This is the only stateful, per-request thing in the forward path — the rest of the model is a pure tensor function. That's why sampling lives outside the model graph.
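A minimal sketch of the row selection, assuming a hypothetical packed batch of three sequences with 3, 1 and 2 query tokens. The stand-in logits_full is filled with its own row index so the picked rows are easy to recognise; a real model output would go in its place.

```python
import torch

# Hypothetical packed batch: three sequences contributing 3, 1 and 2 query tokens.
cu_seqlens_q = torch.tensor([0, 3, 4, 6])   # cumulative offsets, exclusive end
vocab = 5
t_total = int(cu_seqlens_q[-1])             # 6 packed tokens in total

# Stand-in for the model output: row i holds the value i in every column.
logits_full = (
    torch.arange(t_total, dtype=torch.float32).unsqueeze(-1).expand(t_total, vocab)
)

# End offsets are exclusive, so the last token of each sequence sits at end - 1.
last_token_per_seq = cu_seqlens_q[1:] - 1
logits = logits_full[last_token_per_seq]    # [num_seqs, vocab]
```

Dropping the -1 would index rows 3, 4 and 6: the first token of the next sequence for the first two, and an out-of-range row for the last.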
4.2 The canonical pipeline order
logits ─► penalties ─► / temperature ─► top-k mask ─► top-p mask ─► softmax ─► multinomial
│
greedy rows: argmax override
Order is litigated. Common bugs:
- Softmax before top-k/p — top-p's "cumulative ≤ p" math is wrong because it's truncating a distribution that's already been normalised on a different support.
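- Temperature after top-k/p: top-k is rank-based, so its surviving set and the probabilities inside it come out the same either way. Top-p is not: the nucleus is measured on the softmax of whatever logits it sees, so if temperature is applied afterwards the cutoff was chosen on the unscaled distribution.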
- Penalties after temperature — penalties are additive/multiplicative logit shifts. Their magnitude is calibrated on raw logits; temperature scaling changes their effective strength.
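The softmax-before-filters bug can be checked by hand on a four-token vocabulary. The sketch below is plain Python with made-up probabilities; softmax and nucleus_size are illustrative helpers, not part of the lesson code. It compares the nucleus top-p picks when it sees renormalised top-k survivors against the one it picks when it reads full-vocab probabilities.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

def nucleus_size(sorted_probs, p):
    """Count of leading tokens kept, including the one that pushes the sum past p."""
    cum = 0.0
    for i, q in enumerate(sorted_probs):
        cum += q
        if cum > p:
            return i + 1
    return len(sorted_probs)

full_probs = [0.4, 0.3, 0.2, 0.1]            # already sorted descending
logits = [math.log(q) for q in full_probs]
k, p = 2, 0.5

# Canonical order: top-k masks logits, so top-p sees the renormalised survivors
# (about 0.571 and 0.429) and the first token alone already crosses p.
survivors = softmax(logits[:k])
size_canonical = nucleus_size(survivors, p)

# Buggy order: softmax first, so top-p reads 0.4 and 0.3 from the full-vocab
# distribution and needs both tokens to cross p.
size_buggy = nucleus_size(full_probs[:k], p)
```

Same k and p, different kept sets: one token in the canonical order, two in the buggy one.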
4.3 Building blocks
A. Greedy
next_id = logits.argmax(dim=-1) # [num_seqs]
Equivalent to temperature == 0. The L3 driver does this inline; the sampler wraps it.
B. Temperature (with the divide-by-zero guard)
mask_greedy = (temperature == 0)
t_safe = temperature.masked_fill(mask_greedy, 1.0)
logits = logits / t_safe.unsqueeze(-1) # [num_seqs, vocab] / [num_seqs, 1]
# ...at the end, override greedy rows with argmax
next_ids = torch.where(mask_greedy, logits.argmax(-1), next_ids)
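Do not skip the guard: division by zero fills the row with inf and nan (0/0), which poisons the softmax. A tiny non-zero temperature is no substitute for the 0 sentinel either: at 1e-3 the runner-up probabilities underflow to zero after softmax, and in fp16 the scaled logits can overflow to inf.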
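A small sketch of what the guard prevents, on a made-up two-row batch where row 0 is greedy and row 1 samples at 0.5. The variable names follow the snippet above; the input values are illustrative.

```python
import torch

logits = torch.tensor([[2.0, 0.0, -1.0],
                       [2.0, 0.0, -1.0]])
temperature = torch.tensor([0.0, 0.5])       # row 0 greedy, row 1 sampled

# Unguarded division: row 0 becomes [inf, nan, -inf] because 0.0 / 0.0 is nan.
unguarded = logits / temperature.unsqueeze(-1)

# Guarded division: greedy rows are divided by 1.0 and resolved by argmax later.
mask_greedy = temperature == 0
t_safe = temperature.masked_fill(mask_greedy, 1.0)
guarded = logits / t_safe.unsqueeze(-1)
```

The guarded result leaves row 0 untouched and scales row 1 by 2, so every value stays finite.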
C. Top-k
Keep only the top k logits per row; everything else becomes -inf.
if k.max() > 0:
    k_max = int(k.max())
    topk_vals, topk_idx = logits.topk(k_max, dim=-1)
    keep = torch.arange(k_max, device=logits.device)[None, :] < k.unsqueeze(-1)
    new = torch.full_like(logits, float('-inf'))
    new.scatter_(-1, topk_idx,
                 torch.where(keep, topk_vals,
                             torch.tensor(float('-inf'), device=logits.device)))
    logits = torch.where((k > 0).unsqueeze(-1), new, logits)
Per-row k via topk(k_max) + a "keep only first k[i] of each row" mask. Rows with k=0 bypass top-k entirely via the final torch.where.
D. Top-p (nucleus)
Keep the smallest set of tokens whose cumulative probability ≥ p.
if (p < 1.0).any():
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum = sorted_probs.cumsum(dim=-1)
    remove = cum > p.unsqueeze(-1)
    # SHIFT RIGHT: the boundary token (the one that pushes past p) must be kept
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # project sorted mask back to vocab order
    remove_unsorted = torch.empty_like(remove)
    remove_unsorted.scatter_(-1, sorted_idx, remove)
    logits = logits.masked_fill(remove_unsorted, float('-inf'))
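pitfall The shift-right is not optional. Without it, the token that crosses p is itself removed, so the kept mass falls short of p. When the top token alone already exceeds p, every position is masked: the row becomes all -inf, there is nothing left to sample from, and the final softmax returns nan.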
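The effect of the shift is easy to see on plain lists. This is an illustrative sketch in stdlib Python; removal_mask is a hypothetical helper that mirrors the cum > p rule from the snippet above on already-sorted probabilities.

```python
def removal_mask(sorted_probs, p, shift_right):
    """True marks a token to drop; probabilities must be sorted descending."""
    cum, mask = 0.0, []
    for q in sorted_probs:
        cum += q
        mask.append(cum > p)
    if shift_right:
        # Keep the boundary token: each flag moves one slot right, slot 0 is kept.
        mask = [False] + mask[:-1]
    return mask

# Case 1: the top token alone exceeds p.
peaked = [0.95, 0.03, 0.02]
peaked_no_shift = removal_mask(peaked, 0.9, shift_right=False)  # everything dropped
peaked_shifted = removal_mask(peaked, 0.9, shift_right=True)    # top token survives

# Case 2: the second token crosses p.
flat = [0.5, 0.3, 0.2]
flat_no_shift = removal_mask(flat, 0.7, shift_right=False)      # keeps 0.5 < p
flat_shifted = removal_mask(flat, 0.7, shift_right=True)        # keeps 0.8 >= p
```

In both cases the unshifted mask keeps less mass than p asks for; in the first it keeps none at all.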
E. Penalties
# repetition penalty (multiplicative, HF style)
if (rep_p != 1.0).any():
    score = logits.gather(-1, prev_ids)  # [N, max_prev]
    score = torch.where(score > 0,
                        score / rep_p.unsqueeze(-1),
                        score * rep_p.unsqueeze(-1))
    logits.scatter_(-1, prev_ids, score)  # IN-PLACE
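prev_ids shape contract: [num_seqs, max_prev]. Right-pad shorter rows so all rows have the same length. Repeating a token already in the row (for example the first one) is safe: the gather reads the original logit for every duplicate and the scatter overwrites with the same penalised value, so the token is penalised once. Padding with an arbitrary valid id such as 0 is not safe: it penalises a token that was never generated. Otherwise pad with something inert and mask the gather/scatter, or use frequency-count tensors instead (cleaner mathematically).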
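One way the frequency-count variant could look, as a sketch rather than the lesson's implementation. The function name repetition_penalty and the prev_mask argument (True for real entries, False for padding) are assumptions introduced here. Unlike the in-place snippet above, it returns a new tensor.

```python
import torch

def repetition_penalty(logits, prev_ids, prev_mask, rep_p):
    """HF-style penalty applied once to every token that appears in a row's history."""
    # Count real occurrences per vocab entry; padding slots add 0 and stay inert.
    counts = torch.zeros_like(logits)
    counts.scatter_add_(-1, prev_ids, prev_mask.to(logits.dtype))
    seen = counts > 0
    r = rep_p.unsqueeze(-1)
    penalised = torch.where(logits > 0, logits / r, logits * r)
    return torch.where(seen, penalised, logits)

logits = torch.tensor([[2.0, -1.0, 3.0, 0.5],
                       [1.0, 4.0, -2.0, 0.5]])
prev_ids = torch.tensor([[2, 1, 2],          # token 2 appears twice
                         [1, 0, 0]])         # one real token, padded with id 0
prev_mask = torch.tensor([[True, True, True],
                          [True, False, False]])
rep_p = torch.tensor([2.0, 2.0])

out = repetition_penalty(logits, prev_ids, prev_mask, rep_p)
```

Row 1 shows the point of the mask: token 0 is only padding, so its logit comes back unchanged even though 0 is a valid vocabulary id.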
F. Final draw
probs = logits.float().softmax(dim=-1) # cast inside sampler, don't trust caller
next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_ids = torch.where(mask_greedy, logits.argmax(-1), next_ids)
The greedy override is what keeps the smoke test at parity with L3 when temperature=0.
4.4 Disabled-default sentinels (memorize)
| knob | type | disabled value | why |
|---|---|---|---|
| temperature | float | 0.0 | triggers greedy override |
| top_k | int | 0 | convention: 0 = off |
| top_p | float | 1.0 | identity — keeps the whole vocab |
| rep_penalty | float | 1.0 | multiplicative identity |
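If your "disabled" smoke test passes top_p=0.0, then with the shift-right in place the nucleus collapses to the single top token, so every row is effectively greedy and the greedy override hides the mistake. rep_penalty=0.0 divides positive logits by zero and multiplies negative ones by zero: total quality collapse if it ever ran.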
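A sketch of how the sentinels in the table could be baked into a parameter object so that a wrong "disabled" value fails loudly instead of hiding behind the greedy override. The field layout and the accepted ranges are assumptions made for illustration; the lesson does not specify them.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SamplingParams:
    # Defaults are the disabled values from the table above.
    temperature: float = 0.0   # 0.0 triggers the greedy override
    top_k: int = 0             # 0 means off
    top_p: float = 1.0         # 1.0 keeps the whole vocab
    rep_penalty: float = 1.0   # multiplicative identity

    def __post_init__(self):
        # Assumed validity ranges; they reject the two wrong sentinels explicitly.
        if self.temperature < 0.0:
            raise ValueError("temperature must be >= 0")
        if self.top_k < 0:
            raise ValueError("top_k must be >= 0 (0 disables it)")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]; use 1.0 to disable")
        if self.rep_penalty <= 0.0:
            raise ValueError("rep_penalty must be > 0; use 1.0 to disable")

defaults = SamplingParams()
```

With this in place, SamplingParams(top_p=0.0) raises at construction time rather than producing a smoke test that passes by accident.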
4.5 Q&A from session
student question Why does the sampler need per-row tensors and not just scalars?
Because once the scheduler (L5) bunches multiple in-flight requests into one batch, request A might want temperature=0 (you, debugging) while request B wants temperature=0.7 (production user). The whole point of batching is that we run them in one kernel call. So sampling parameters become per-row from day one, even if the L4 smoke uses a single row.
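A sketch of the packing step under that assumption: two hypothetical in-flight requests, one greedy and one sampling, turned into 1-D tensors of length num_seqs. The dict layout and the pack helper are illustrative names, not the lesson's API.

```python
import torch

# Hypothetical per-request settings: request A debugs greedily, request B samples.
requests = [
    {"temperature": 0.0, "top_k": 0, "top_p": 1.0, "rep_penalty": 1.0},
    {"temperature": 0.7, "top_k": 50, "top_p": 0.9, "rep_penalty": 1.0},
]

def pack(requests, device="cpu"):
    """Turn a list of per-request settings into one 1-D tensor per knob."""
    def column(name, dtype):
        return torch.tensor([r[name] for r in requests], dtype=dtype, device=device)
    return {
        "temperature": column("temperature", torch.float32),
        "top_k": column("top_k", torch.long),
        "top_p": column("top_p", torch.float32),
        "rep_penalty": column("rep_penalty", torch.float32),
    }

packed = pack(requests)

# Each row is scaled by its own temperature: [num_seqs, vocab] / [num_seqs, 1].
logits = torch.ones(len(requests), 8)
t = packed["temperature"]
scaled = logits / t.masked_fill(t == 0, 1.0).unsqueeze(-1)
```

Row 0 is left at 1.0 for the greedy override to resolve, while row 1 is divided by 0.7; a scalar temperature could not express both in one call.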
student question Why softmax twice (once for top-p, once for the final draw)?
The top-p softmax is a means to an end: it converts logits into a CDF so we can find the nucleus. After we mask logits to -inf, the final softmax over the masked logits gives the correct re-normalised distribution. Dividing the kept probabilities by their sum would give the same numbers, but masking in logit space lets every filter compose and leaves a single softmax at the end to do the normalising.
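A worked check of that renormalisation in stdlib Python, on made-up logits with the nucleus for p = 0.8 written out by hand. The softmax helper is illustrative; it relies on exp(-inf) being 0.0 so masked entries drop out.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]     # exp(-inf) is 0.0 for masked entries
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.0, -1.0]

# First softmax: only used to locate the nucleus. The probabilities are roughly
# 0.644, 0.237, 0.087, 0.032, so for p = 0.8 the second token crosses the threshold.
first = softmax(logits)
keep = [True, True, False, False]

# Mask in logit space, then let the final softmax renormalise.
masked = [x if k else float("-inf") for x, k in zip(logits, keep)]
final = softmax(masked)

# Reference: the first distribution restricted to the nucleus and rescaled.
kept_mass = sum(q for q, k in zip(first, keep) if k)
rescaled = [q / kept_mass if k else 0.0 for q, k in zip(first, keep)]
```

The two lists agree to rounding error, which is what "correct re-normalised distribution" means here.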
4.6 Implementation review (bugs we hit)
From reviewing the student implementation:
- Disabled-default sentinels were wrong. Driver passed top_p=0.0 and rep_penalty=0.0; correct sentinels are 1.0 for both. The smoke "passed" only because the greedy override masks the damage.
- Sampler ignored its own argument. rep_penalty was hardcoded to 1.2 inside the sampler instead of reading sample_cfg.rep_penalty.
- Non-inplace scatter. logits.scatter(-1, prev_ids, score) returns a new tensor; the original logits was unchanged. Use scatter_.
- prev_ids shape mismatch. Sampler did prev_ids.unsqueeze(-1) to make it 2-D, but the result had num_prev on dim 0 instead of num_seqs. gather raises as soon as num_prev > 1. The contract is [num_seqs, max_prev], period; the driver should build it as torch.stack(output_ids).unsqueeze(0) for single-seq.
- Top-p always runs. No if (top_p < 1.0).any() short-circuit; pays a full sort even when disabled.
- Softmax precision left to caller. Move .float() inside the sampler so it's not load-bearing on the driver.
- Garbage IDE auto-import at top of smoke (SAMPLE_INPUTS_FILENAME_FORMAT). Delete.
- Wrong pass banner. Smoke prints "L3 PASS" instead of "L4 PASS".
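Two of these bugs are quick to reproduce in isolation. The sketch below uses tiny made-up tensors to show the difference between scatter and scatter_, and which unsqueeze yields the [num_seqs, max_prev] layout for a single sequence.

```python
import torch

logits = torch.tensor([[1.0, 2.0, 3.0]])
prev_ids = torch.tensor([[2]])
score = torch.tensor([[1.5]])

# Out-of-place: the result is a new tensor and logits keeps its old values.
out_of_place = logits.scatter(-1, prev_ids, score)
snapshot_after_scatter = logits.clone()

# In-place: the trailing underscore mutates logits itself.
logits.scatter_(-1, prev_ids, score)

# Shape contract for one sequence with three previous tokens.
prev = torch.tensor([5, 7, 9])
wrong = prev.unsqueeze(-1)   # [num_prev, 1]: history ends up on dim 0
right = prev.unsqueeze(0)    # [1, num_prev]: matches [num_seqs, max_prev]
```

Discarding the return value of the out-of-place call is exactly the "penalty silently does nothing" failure.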
4.7 Implementation order
- Define SamplingParams + per-row tensor packing helper.
- Greedy-only Sampler.__call__. Wire into smoke. Confirm same 20 IDs as L3.
- Add temperature with the == 0 guard. Re-test at temp=0.
- Add top-k (and its k=0 bypass). Re-test.
- Add top-p (with shift-right and p<1.0 guard). Re-test.
- Add repetition penalty (with scatter_ and rp!=1.0 guard). Re-test.
- Sanity check: at temp=0.7, top_p=0.9, top_k=50 over 64 tokens, output should be coherent English. Visual only, no bit-equality.
Every step keeps the smoke test as a regression: every disabled default must still produce the L3 output. If any step breaks parity with disabled defaults, an op escaped its if disabled: skip guard.
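The parity regression can be sketched end to end on random logits. The function below is one possible assembly of the building blocks in the pipeline order, not the lesson's Sampler class; the name sample, its argument order, and the per-row top_p < 1.0 mask are assumptions added here. Random logits stand in for a model, so the check compares against argmax rather than the L3 IDs.

```python
import torch

def sample(logits, temperature, top_k, top_p, rep_p, prev_ids=None, generator=None):
    """All sampling params are 1-D tensors of length num_seqs."""
    logits = logits.float().clone()

    # Penalties: skipped when every row carries the identity 1.0.
    if prev_ids is not None and (rep_p != 1.0).any():
        r = rep_p.unsqueeze(-1)
        score = logits.gather(-1, prev_ids)
        score = torch.where(score > 0, score / r, score * r)
        logits.scatter_(-1, prev_ids, score)

    # Temperature: greedy rows are divided by 1.0 and overridden at the end.
    mask_greedy = temperature == 0
    logits = logits / temperature.masked_fill(mask_greedy, 1.0).unsqueeze(-1)

    # Top-k: rows with k == 0 keep their logits untouched.
    if (top_k > 0).any():
        k_max = min(int(top_k.max()), logits.size(-1))
        vals, idx = logits.topk(k_max, dim=-1)
        keep = torch.arange(k_max, device=logits.device)[None, :] < top_k.unsqueeze(-1)
        vals = vals.masked_fill(~keep, float("-inf"))
        masked = torch.full_like(logits, float("-inf")).scatter_(-1, idx, vals)
        logits = torch.where((top_k > 0).unsqueeze(-1), masked, logits)

    # Top-p: shift right so the boundary token survives.
    if (top_p < 1.0).any():
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum > top_p.unsqueeze(-1)
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        remove &= (top_p < 1.0).unsqueeze(-1)   # rows at 1.0 stay untouched
        remove = torch.zeros_like(remove).scatter_(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = logits.softmax(dim=-1)
    next_ids = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    return torch.where(mask_greedy, logits.argmax(-1), next_ids)

# Regression: all-disabled defaults must reproduce plain argmax.
torch.manual_seed(0)
logits = torch.randn(4, 32)
n = logits.size(0)
disabled = dict(
    temperature=torch.zeros(n),
    top_k=torch.zeros(n, dtype=torch.long),
    top_p=torch.ones(n),
    rep_p=torch.ones(n),
)
parity = torch.equal(sample(logits, **disabled), logits.argmax(-1))
```

If parity turns False after adding a stage, that stage is running for rows that should have skipped it.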
4.8 Pitfalls recap
| pitfall | symptom |
|---|---|
| softmax before top-k/p | top-p threshold uses wrong distribution; silent quality drop |
| temperature == 0 not guarded | nan or all -inf after first sample |
| top-p without shift-right | first token can be removed → all -inf row, nan softmax |
| bf16 multinomial | low-prob tokens collapse, distribution spikier than intended |
| per-row params passed as scalars | first request's setting silently shared with all |
| non-inplace scatter on logits | penalty silently does nothing |
| top_p=0 or rep_penalty=0 as "disabled" | wrong sentinels; greedy override hides the bug |
4.9 Acceptance
L4 pass criteria
- scripts/l4_smoke.py greedy path produces the same 20 IDs as L1/L2/L3.
- Sampler accepts per-row temperature, top_k, top_p, rep_penalty.
- Each disabled-default setting individually preserves bit-equality.
- At temp=0.7, top_p=0.9, top_k=50 for 64 tokens, output is coherent English.
When this passes, L5 (scheduler) becomes "build cu_seqlens_q, seq_lens_kv, block_table, slot_mapping, and the per-row sampling params from a list of in-flight Request objects every step". Mechanical.