mini_sglang

L1 — Model & weight loading

Build Qwen3-8B from spec, load HF safetensors, match HF token-for-token.

1.1 Goal

By the end of this lesson, python -m scripts.l1_smoke prints "L1 PASS" and emits the same 20 tokens as Hugging Face on the same prompt. The KV cache is still a contiguous per-layer tensor ([max_len, H_kv, D]); paging arrives in L2.

1.2 Qwen3-8B at a glance

field                    value     note
num_hidden_layers        36
hidden_size              4096
intermediate_size (MLP)  12288     SwiGLU, no bias
num_attention_heads      32
num_key_value_heads      8         GQA, group=4
head_dim                 128
RoPE θ                   1e6       "large theta"
rms_norm_eps             1e-6
tie_word_embeddings      false     separate lm_head
q_norm / k_norm          per-head  Qwen3 quirk; applied before RoPE
sliding_window           null      full attention
dtype                    bf16
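
The table's values also fix how large the contiguous KV cache from 1.1 is. A rough sketch of that arithmetic, assuming bf16 storage (2 bytes per element) and an illustrative max_len of 4096, which this lesson does not specify:

```python
# Values taken from the Qwen3-8B table above.
NUM_LAYERS = 36
NUM_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_ELEM = 2  # bf16


def kv_cache_bytes(max_len: int) -> int:
    """Bytes for the K and V caches of shape [max_len, H_kv, D] over all layers."""
    per_layer = max_len * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_ELEM
    return 2 * NUM_LAYERS * per_layer  # factor 2: one K tensor and one V tensor


print(f"per token:    {kv_cache_bytes(1) / 1024:.0f} KiB")
print(f"max_len=4096: {kv_cache_bytes(4096) / 2**20:.0f} MiB")  # illustrative max_len
```

Because the cache stores 8 KV heads rather than 32, it is a quarter of the size it would be without GQA.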

1.3 Implementation order

  1. RMSNorm. x * rsqrt(mean(x²) + eps) * weight. Compute in fp32, cast back.
  2. SwiGLU MLP. down(silu(gate(x)) * up(x)). No bias.
  3. RoPE precompute. inv_freq = 1 / θ ** (arange(0, D, 2) / D); build cos/sin tables of shape [max_pos, D] using HF's rotate-half layout (cat(freqs, freqs), NOT interleaved).
  4. RoPE apply. x * cos + rotate_half(x) * sin, where rotate_half(x) = cat(-x[..., D/2:], x[..., :D/2], dim=-1).
  5. Qwen3Attention. Project Q/K/V (no bias), reshape to [T, H, D], q_norm and k_norm per head, apply RoPE, write to KV cache, gather-then-SDPA.
  6. Qwen3DecoderLayer. Standard pre-norm: x + attn(norm(x)), x + mlp(norm(x)).
  7. Qwen3Model. embed → 36 layers → final norm.
  8. Qwen3ForCausalLM. + lm_head.
  9. weights.py. Load safetensors shards, remap names, load_state_dict(strict=True).
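
Steps 3 and 4 are the easiest to get subtly wrong, so here is a dependency-free sketch of the rotate-half convention on a single head vector. It uses plain Python lists instead of tensors, and the function names are illustrative rather than the lesson's actual module layout.

```python
import math


def rope_angles(pos: int, D: int, base: float = 1e6) -> list[float]:
    """Angles for one position in rotate-half layout: cat(freqs, freqs)."""
    inv_freq = [1.0 / (base ** (i / D)) for i in range(0, D, 2)]
    freqs = [pos * f for f in inv_freq]
    return freqs + freqs  # length D, NOT interleaved


def rotate_half(x: list[float]) -> list[float]:
    """cat(-x[D/2:], x[:D/2])."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: list[float], pos: int, base: float = 1e6) -> list[float]:
    """x * cos + rotate_half(x) * sin for one head vector of length D."""
    angles = rope_angles(pos, len(x), base)
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a)
            for xi, ri, a in zip(x, rotated, angles)]


x = [1.0, 2.0, 3.0, 4.0]
print(apply_rope(x, pos=0))  # all angles are 0 at position 0, so x is unchanged
print(apply_rope(x, pos=5))  # element i is rotated together with element i + D/2
```

Two cheap sanity checks follow from this: position 0 is the identity, and the vector's norm is the same before and after rotation at any position.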

1.4 Reference snippets (just the tricky ones)

RMSNorm — cast order matters

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.to(torch.float32)
        var = x32.pow(2).mean(-1, keepdim=True)
        x32 = x32 * torch.rsqrt(var + self.eps)
        # NOTE: weight stays bf16; multiply happens after cast back
        return self.weight * x32.to(in_dtype)

RoPE precompute — the parens trap

def precompute_rope_cache(D, max_pos, base=1e6, device="cuda", dtype=torch.bfloat16):
    arange = torch.arange(0, D, 2, device=device, dtype=torch.float32)
    # CORRECT: base ** (arange / D)
    inv_freq = 1.0 / (base ** (arange / D))
    t = torch.arange(max_pos, device=device, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)        # [max_pos, D/2]
    emb   = torch.cat((freqs, freqs), dim=-1)           # [max_pos, D]   rotate-half layout
    return emb.cos().to(dtype), emb.sin().to(dtype)
Pitfall: Python evaluates base ** arange / D as (base ** arange) / D. Dropping the inner parens from the line above gives 1.0 / ((base ** arange) / D), so inv_freq[0] comes out as D (128 here) instead of 1, every later entry is vanishingly small, and your positional encoding is badly wrong. This bit us while debugging L1: the model produced a sensible token 0 (RoPE is the identity at position 0) and garbage from token 1 onward.
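
A quick way to see the effect of the missing parentheses is to evaluate both forms for the first few exponents. The sketch below uses plain floats and assumes the buggy line was written as 1.0 / (base ** arange / D); the helper names are made up for illustration.

```python
D, base = 128, 1e6


def inv_freq_correct(i: int) -> float:
    return 1.0 / (base ** (i / D))


def inv_freq_buggy(i: int) -> float:
    # Without the inner parens Python parses this as 1.0 / ((base ** i) / D).
    return 1.0 / (base ** i / D)


# Only the first few even exponents: base ** i overflows a Python float for
# larger i (in fp32 tensors it would become inf, and 1 / inf gives 0).
for i in (0, 2, 4):
    print(i, inv_freq_correct(i), inv_freq_buggy(i))
```

In the correct version the first frequency is exactly 1 and later ones decay slowly; in the buggy version the first entry is no longer 1 and the rest fall off by many orders of magnitude at once.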

Qwen3Attention — quirks

q = self.q_proj(x).view(T, H_q, D)
k = self.k_proj(x).view(T, H_kv, D)
v = self.v_proj(x).view(T, H_kv, D)

# Qwen3-specific: per-head RMSNorm BEFORE rope
q = self.q_norm(q)
k = self.k_norm(k)

q, k = apply_rope(q, k, cos[positions], sin[positions])

# write to contiguous KV (L1) — paged in L2
K_cache[layer, kv_write_offset : kv_write_offset + T] = k
V_cache[layer, kv_write_offset : kv_write_offset + T] = v
K = K_cache[layer, :kv_valid_len]
V = V_cache[layer, :kv_valid_len]

# GQA: repeat KV heads to match Q heads
if H_q != H_kv:
    K = K.repeat_interleave(H_q // H_kv, dim=1)
    V = V.repeat_interleave(H_q // H_kv, dim=1)

# is_causal only when T == kv_valid_len (prefill, no prior cache)
is_prefill = (T == kv_valid_len)
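out = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),   # [1, H_q, T, D]  (q, not Q: only K/V are re-read from the cache)
    K.transpose(0, 1).unsqueeze(0),   # [1, H_q, kv_valid_len, D]
    V.transpose(0, 1).unsqueeze(0),
    is_causal=is_prefill, dropout_p=0.0,
).squeeze(0).transpose(0, 1)          # back to [T, H_q, D]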
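
The GQA step depends on repeat_interleave rather than a tiling repeat: query head h has to read KV head h // group. A small list-based sketch of the difference, where the two helper functions are stand-ins for the tensor operations and not part of the lesson's code:

```python
def repeat_interleave(heads: list, n: int) -> list:
    """List analogue of Tensor.repeat_interleave(n, dim=1): [a, b] -> [a, a, b, b]."""
    return [h for h in heads for _ in range(n)]


def tile(heads: list, n: int) -> list:
    """List analogue of tiling the head axis: [a, b] -> [a, b, a, b]."""
    return heads * n


H_q, H_kv = 32, 8
group = H_q // H_kv                    # 4 query heads share one KV head
kv_heads = list(range(H_kv))           # label each KV head by its index

expanded = repeat_interleave(kv_heads, group)
tiled = tile(kv_heads, group)

print(expanded[:8])  # consecutive query heads share a KV head
print(tiled[:8])     # tiling pairs query heads with the wrong KV heads
```

Both results have 32 entries, so a tiling bug does not show up as a shape error, only as wrong tokens.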

1.5 Q&A from session

Student question: Why is is_prefill = (T == kv_valid_len) true?

It's true by caller convention. Two cases:

step                       T (queries this call)  kv_valid_len (total cached)  is_prefill
prefill of N-token prompt  N                      N (we just wrote N)          true
decode step i              1                      N + i                        false

So T == kv_valid_len is shorthand for "this batch's queries cover the entire history". When it's true, SDPA's built-in causal mask is the right mask. When it's false (decode, T = 1), the single new token may attend to the entire cached history, so no mask is needed.
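
Both cases are instances of one rule: query i sits at absolute position kv_len - T + i and may see every key at or before that position. The sketch below builds that mask with plain lists, assuming the T queries are the last T positions of the cache (true in L1, where we write and then read). It also builds the plain lower-triangular mask that PyTorch's documentation describes for is_causal=True. The function names are illustrative.

```python
def causal_mask_bottom_right(T: int, kv_len: int) -> list[list[bool]]:
    """mask[i][j] is True when query i may attend to key j."""
    offset = kv_len - T  # number of cached positions that precede the first query
    return [[j <= i + offset for j in range(kv_len)] for i in range(T)]


def causal_mask_top_left(T: int, kv_len: int) -> list[list[bool]]:
    """Plain lower-triangular [T, kv_len] mask, anchored at the top-left corner."""
    return [[j <= i for j in range(kv_len)] for i in range(T)]


prefill = causal_mask_bottom_right(3, 3)  # lower triangle: the standard causal mask
decode = causal_mask_bottom_right(1, 4)   # one row, all True: no mask needed

print(prefill)
print(decode)
print(causal_mask_top_left(1, 4))         # would hide all but the first key
```

When T == kv_len the two masks coincide, which is why is_causal=True is safe for prefill. For decode they differ, so the flag has to be off.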

Student push-back: "This only works because we pass kv_valid_len = T during prefill. It's a coupling smell."

Correct instinct. The cleaner formulation is per-sequence: q_len_i == kv_len_i. We'll switch to that representation in L3 once cu_seqlens/seq_lens exist. For L1's contiguous, single-sequence case, the boolean is fine and it removes one parameter from the forward signature.

Student question: How does the KV cache actually help? What computation did we save?

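At decode step i the history length is L = N + i. Without the cache, you would recompute K and V for all L positions at every step. Each position's K and V depend only on that position's hidden state at that layer, but without a cache those hidden states have to be recomputed too, so the whole forward pass reruns over all L tokens. With the cache, only the one new token goes through each layer: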

Weight-matmul FLOPs per layer for Qwen3-8B (4096 hidden, 12288 MLP, 32 Q heads, 8 KV heads, 128 head_dim), counting 2 FLOPs per multiply-add:

  Per-token compute (decode with cache):   ~386 MFLOPs / token / layer
  Per-step recompute (no cache, L=1024):   ~386 MFLOPs * 1024  ≈ 395 GFLOPs / step / layer
  Speedup at L=1024: ~1024×                (roughly: linear in context length)

The KV cache turns the attention work from O(L²) per step into O(L) per step, and the projection and MLP work from O(L) per step into O(1). That is the whole reason "fast inference" is even feasible.
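
The per-layer figure can be tallied from the config values. This sketch counts only the weight matmuls (Q/K/V/O projections and the three MLP matrices) at 2 FLOPs per multiply-add, and ignores attention scores, norms, and RoPE. The absolute number depends on that counting convention; the ratio between the two cases does not.

```python
HIDDEN, INTERMEDIATE = 4096, 12288
H_Q, H_KV, HEAD_DIM = 32, 8, 128


def layer_weight_params() -> int:
    """Weight-matrix elements in one decoder layer (no biases, norms excluded)."""
    attn = HIDDEN * H_Q * HEAD_DIM          # q_proj
    attn += 2 * HIDDEN * H_KV * HEAD_DIM    # k_proj and v_proj (GQA: 8 heads)
    attn += H_Q * HEAD_DIM * HIDDEN         # o_proj
    mlp = 3 * HIDDEN * INTERMEDIATE         # gate, up, down
    return attn + mlp


def matmul_flops_per_token() -> int:
    """One multiply and one add per weight element per token."""
    return 2 * layer_weight_params()


L = 1024
with_cache = matmul_flops_per_token()   # only the new token is processed
without_cache = L * with_cache          # all L positions are recomputed

print(f"with cache:    {with_cache / 1e6:.0f} MFLOPs / step / layer")
print(f"without cache: {without_cache / 1e9:.0f} GFLOPs / step / layer")
print(f"ratio:         {without_cache // with_cache}x")
```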

1.6 The smoke harness

Loading both your model and HF simultaneously OOMs on a 24-32 GiB GPU (Qwen3-8B bf16 ≈ 16 GiB each). Load sequentially:

def run_ours(prompt_ids):
    model = load_model(MODEL_DIR); model.eval()
    out = our_greedy(model, prompt_ids, N_NEW).cpu()
    del model
    gc.collect(); torch.cuda.empty_cache(); torch.cuda.synchronize()
    return out

def run_hf(prompt_ids):
    hf = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16).cuda().eval()
    hf_out = hf.generate(prompt_ids[None], do_sample=False, max_new_tokens=N_NEW)
    out = hf_out[0, prompt_ids.numel():].cpu()
    del hf, hf_out
    gc.collect(); torch.cuda.empty_cache(); torch.cuda.synchronize()
    return out
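
The reason for loading sequentially can be checked with a back-of-the-envelope parameter count. In this sketch VOCAB_SIZE is an assumption that this lesson does not state (read it from the checkpoint's config.json), and the small norm weights are left out.

```python
NUM_LAYERS, HIDDEN, INTERMEDIATE = 36, 4096, 12288
H_Q, H_KV, HEAD_DIM = 32, 8, 128
VOCAB_SIZE = 151_936  # assumption: not given in this lesson; check config.json


def total_params() -> int:
    """Approximate parameter count; RMSNorm weights omitted as negligible."""
    attn = 2 * HIDDEN * H_Q * HEAD_DIM + 2 * HIDDEN * H_KV * HEAD_DIM  # q, o, k, v
    mlp = 3 * HIDDEN * INTERMEDIATE                                    # gate, up, down
    embeddings = 2 * VOCAB_SIZE * HIDDEN     # embed_tokens plus the untied lm_head
    return NUM_LAYERS * (attn + mlp) + embeddings


gib = total_params() * 2 / 2**30             # 2 bytes per parameter in bf16
print(f"{total_params() / 1e9:.2f}B params")
print(f"one copy:   {gib:.1f} GiB")
print(f"two copies: {2 * gib:.1f} GiB")
```

Two resident copies leave little or no room for activations and the KV cache on a 24-32 GiB card, which is why each run above frees its model before the next one loads.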

1.7 Acceptance

L1 pass criteria: our_ids == hf_ids, exactly equal for all 20 greedy tokens on the prompt "The capital of France is". Expected: [12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625] → " Paris. The capital of France is Paris. The capital of France is Paris. The capital of France".
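
When the comparison fails, the position of the first differing token is more useful than a bare boolean. A minimal sketch of such a check, where first_mismatch is a hypothetical helper and not part of scripts.l1_smoke:

```python
from typing import Optional, Sequence

# Expected ids copied from the pass criteria above.
EXPECTED = [12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576,
            6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625]


def first_mismatch(ours: Sequence[int], ref: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are equal.

    If one sequence is a strict prefix of the other, the length of the
    shorter one is returned.
    """
    for i, (a, b) in enumerate(zip(ours, ref)):
        if a != b:
            return i
    if len(ours) != len(ref):
        return min(len(ours), len(ref))
    return None


ours = list(EXPECTED)  # stand-in for the ids produced by our model
idx = first_mismatch(ours, EXPECTED)
print("L1 PASS" if idx is None else f"L1 FAIL: first mismatch at token {idx}")
```

Reporting the index makes it easier to tell output that is wrong from the first token from output that drifts later.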

1.8 Debug log (real bugs we hit)
