L1 — Model & weight loading
1.1 Goal
By the end of this lesson, `python -m scripts.l1_smoke` prints "L1 PASS" and emits the same 20 greedy tokens as Hugging Face on the same prompt. The KV cache is still a contiguous per-layer tensor (`[max_len, H_kv, D]`); paging arrives in L2.
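Here is a minimal sketch of that contiguous cache. It assumes the per-layer slabs are stacked into one `[num_layers, max_len, H_kv, D]` tensor, so that `K_cache[layer]` is the `[max_len, H_kv, D]` slab (this matches the indexing in the attention snippet in 1.4). `alloc_kv_cache` and `kv_cache_bytes` are hypothetical helper names. The arithmetic shows how much memory a contiguous cache reserves up front for Qwen3-8B's 36 layers and 8 KV heads.

```python
import torch

def alloc_kv_cache(num_layers, max_len, n_kv_heads, head_dim,
                   dtype=torch.bfloat16, device="cpu"):
    # One contiguous [max_len, H_kv, D] slab per layer, stacked along dim 0.
    # Pass device="cuda" for real use; "cpu" keeps the sketch runnable anywhere.
    shape = (num_layers, max_len, n_kv_heads, head_dim)
    K_cache = torch.zeros(shape, dtype=dtype, device=device)
    V_cache = torch.zeros(shape, dtype=dtype, device=device)
    return K_cache, V_cache

def kv_cache_bytes(num_layers, max_len, n_kv_heads, head_dim, bytes_per_elem=2):
    # K and V, every layer, every slot, reserved whether or not it is used yet
    return 2 * num_layers * max_len * n_kv_heads * head_dim * bytes_per_elem

# Qwen3-8B in bf16 (2 bytes per element)
print(kv_cache_bytes(36, 1, 8, 128) // 1024)      # 144 (KiB per token)
print(kv_cache_bytes(36, 4096, 8, 128) // 2**20)  # 576 (MiB for max_len = 4096)
```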
1.2 Qwen3-8B at a glance
| field | value | note |
|---|---|---|
| num_hidden_layers | 36 | |
| hidden_size | 4096 | |
| intermediate_size (MLP) | 12288 | SwiGLU, no bias |
| num_attention_heads | 32 | |
| num_key_value_heads | 8 | GQA, group=4 |
| head_dim | 128 | |
| RoPE θ | 1e6 | "large theta" |
| rms_norm_eps | 1e-6 | |
| tie_word_embeddings | false | separate lm_head |
| q_norm / k_norm | RMSNorm over head_dim, applied per head | Qwen3 quirk; one [head_dim] weight shared by all heads; applied before RoPE |
| sliding_window | null | full attention |
| dtype | bf16 | |
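The table maps naturally onto a frozen config object. In this minimal sketch the field names follow HF's `config.json` keys. The derived properties (`gqa_group`, `q_dim`, `kv_dim`) are hypothetical helper names for the projection widths the attention code needs. The sketch reads `head_dim` from the config rather than computing `hidden_size // num_attention_heads`. For this model the two happen to agree, but Qwen3 stores `head_dim` as its own field.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Qwen3Config:
    # Values from the table above (Qwen3-8B)
    num_hidden_layers: int = 36
    hidden_size: int = 4096
    intermediate_size: int = 12288
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128
    rope_theta: float = 1e6
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = False

    @property
    def gqa_group(self) -> int:
        # number of query heads that share each KV head
        assert self.num_attention_heads % self.num_key_value_heads == 0
        return self.num_attention_heads // self.num_key_value_heads

    @property
    def q_dim(self) -> int:
        # out_features of q_proj (and in_features of o_proj)
        return self.num_attention_heads * self.head_dim

    @property
    def kv_dim(self) -> int:
        # out_features of k_proj and v_proj
        return self.num_key_value_heads * self.head_dim

cfg = Qwen3Config()
print(cfg.gqa_group, cfg.q_dim, cfg.kv_dim)  # 4 4096 1024
```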
1.3 Implementation order
- RMSNorm. `x * rsqrt(mean(x²) + eps) * weight`. Compute in fp32, cast back.
- SwiGLU MLP. `down(silu(gate(x)) * up(x))`. No bias.
- RoPE precompute. `inv_freq = 1 / θ ** (arange(0, D, 2) / D)`; build `cos`/`sin` tables of shape `[max_pos, D]` using HF's rotate-half layout (`cat(freqs, freqs)`, NOT interleaved).
- RoPE apply. `x * cos + rotate_half(x) * sin`, where `rotate_half(x) = cat(-x[..., D/2:], x[..., :D/2], dim=-1)`.
- Qwen3Attention. Project Q/K/V (no bias), reshape to `[T, H, D]`, apply `q_norm` and `k_norm` per head, apply RoPE, write to the KV cache, gather, run SDPA, then project back with `o_proj`.
- Qwen3DecoderLayer. Standard pre-norm: `x = x + attn(norm(x))`, then `x = x + mlp(norm(x))` (two separate RMSNorm weights).
- Qwen3Model. embed → 36 layers → final norm.
- Qwen3ForCausalLM. `Qwen3Model` + separate `lm_head`.
- weights.py. Load safetensors shards, remap names, `load_state_dict(strict=True)`.
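Two of the steps above have no snippet in 1.4, so here are sketches of both in PyTorch. The attribute names `gate_proj`, `up_proj` and `down_proj` follow HF's Qwen3 checkpoint naming, which keeps the remapping in weights.py small; verify them against your safetensors keys. `apply_rope` assumes `cos`/`sin` rows have already been gathered by position, as in the attention snippet. It unsqueezes them so one `[T, D]` table broadcasts over every head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Qwen3MLP(nn.Module):
    # SwiGLU: down(silu(gate(x)) * up(x)); all three projections are bias-free
    def __init__(self, hidden, inter):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

def rotate_half(x):
    # Rotate-half layout: dimension i is paired with i + D/2 (not 2i with 2i + 1)
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rope(q, k, cos, sin):
    # q: [T, H_q, D], k: [T, H_kv, D]; cos/sin: [T, D], already indexed by position
    cos = cos.unsqueeze(-2)  # [T, 1, D], broadcasts over heads
    sin = sin.unsqueeze(-2)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

Because `cos`/`sin` are duplicated across the two halves, each pair `(x[i], x[i + D/2])` is rotated by the same angle. Vector norms are therefore preserved and position 0 (angle 0) is left untouched.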
1.4 Reference snippets (just the tricky ones)
RMSNorm — cast order matters
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.to(torch.float32)                # normalize in fp32
        var = x32.pow(2).mean(-1, keepdim=True)
        x32 = x32 * torch.rsqrt(var + self.eps)
        # NOTE: weight stays bf16; multiply happens after cast back (HF's order)
        return self.weight * x32.to(in_dtype)
```
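To see why the order matters, compare both orders on random bf16 input. In this sketch the seed, shapes and weight values are arbitrary. The two results agree to within bf16 rounding, but they generally are not bit-identical. HF's order rounds twice (once on the cast back, again after the bf16 multiply) while the fp32 order rounds once. Over 36 layers of greedy decoding, a one-ulp difference can eventually flip a token.

```python
import torch

torch.manual_seed(0)
D, eps = 4096, 1e-6
x = torch.randn(8, D).to(torch.bfloat16)
weight = (1 + 0.1 * torch.randn(D)).to(torch.bfloat16)

x32 = x.float()
normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)

hf_order = weight * normed.to(torch.bfloat16)               # cast, then bf16 multiply
fp32_order = (weight.float() * normed).to(torch.bfloat16)   # multiply in fp32, cast once

n_diff = (hf_order != fp32_order).sum().item()
print(f"{n_diff} of {hf_order.numel()} elements differ")
```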
RoPE precompute — the parens trap
```python
import torch

def precompute_rope_cache(D, max_pos, base=1e6, device="cuda", dtype=torch.bfloat16):
    arange = torch.arange(0, D, 2, device=device, dtype=torch.float32)
    # CORRECT: base ** (arange / D)
    inv_freq = 1.0 / (base ** (arange / D))
    t = torch.arange(max_pos, device=device, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_pos, D/2]
    emb = torch.cat((freqs, freqs), dim=-1)       # [max_pos, D] rotate-half layout
    return emb.cos().to(dtype), emb.sin().to(dtype)
```
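pitfall
Python evaluates `base ** arange / D` as `(base ** arange) / D`. Without the inner parens, `inv_freq` becomes `D / base ** arange`. That puts `inv_freq[0]` at `D` (128) instead of 1, and every later entry collapses toward 0. In fp32, `base ** arange` even overflows to `inf` once `arange` reaches 8. Your positional encoding is hilariously wrong. This bit us at L1 debug; the model produced a sensible token 0 (RoPE doesn't matter at pos 0) and garbage from token 1 onward.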
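The trap is easy to reproduce on CPU. This sketch compares the two parses for D = 128 and θ = 1e6. It also checks that row 0 of the `cos`/`sin` tables is the identity rotation under either parse, which is why position 0 looks fine.

```python
import torch

D, base = 128, 1e6
arange = torch.arange(0, D, 2, dtype=torch.float32)

right = 1.0 / (base ** (arange / D))  # exponents in [0, 1): starts at 1, decays smoothly
wrong = 1.0 / (base ** arange / D)    # parsed as 1 / ((base ** arange) / D) = D / base ** arange

print(right[:4])
print(wrong[:4])  # wrong[0] is 128; later entries are ~0, exactly 0 once base ** arange overflows

# Row 0 of the cos/sin tables is the identity rotation under either parse
t = torch.arange(4, dtype=torch.float32)
for inv_freq in (right, wrong):
    emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)  # rotate-half layout
    assert torch.all(emb[0].cos() == 1) and torch.all(emb[0].sin() == 0)
```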
Qwen3Attention — quirks
```python
# Fragment of Qwen3Attention.forward: x is [T, hidden] for one sequence;
# H_q, H_kv, D come from the config.
q = self.q_proj(x).view(T, H_q, D)
k = self.k_proj(x).view(T, H_kv, D)
v = self.v_proj(x).view(T, H_kv, D)
# Qwen3-specific: per-head RMSNorm (over D) BEFORE rope
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rope(q, k, cos[positions], sin[positions])
# write to contiguous KV (L1); paged in L2
K_cache[layer, kv_write_offset : kv_write_offset + T] = k
V_cache[layer, kv_write_offset : kv_write_offset + T] = v
K = K_cache[layer, :kv_valid_len]  # [kv_valid_len, H_kv, D]
V = V_cache[layer, :kv_valid_len]
# GQA: repeat KV heads to match Q heads
if H_q != H_kv:
    K = K.repeat_interleave(H_q // H_kv, dim=1)
    V = V.repeat_interleave(H_q // H_kv, dim=1)
# is_causal only when T == kv_valid_len (prefill, no prior cache)
is_prefill = (T == kv_valid_len)
out = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),  # [1, H_q, T, D]
    K.transpose(0, 1).unsqueeze(0),  # [1, H_q, kv_valid_len, D]
    V.transpose(0, 1).unsqueeze(0),
    is_causal=is_prefill, dropout_p=0.0,
).squeeze(0).transpose(0, 1)         # [T, H_q, D]
out = self.o_proj(out.reshape(T, H_q * D))  # back to [T, hidden]
```
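This is a CPU-sized check of the GQA layout, using hypothetical helper names `sdpa_gqa` and `naive_gqa`. It compares the repeat-then-SDPA path from the snippet against an explicit per-head loop in which query head `h` reads KV head `h // group`. That mapping is what `repeat_interleave` produces and what HF's `repeat_kv` does. By contrast, `.repeat` would pair head `h` with KV head `h % H_kv`.

```python
import math
import torch
import torch.nn.functional as F

def sdpa_gqa(q, K, V, is_causal):
    # q: [T, H_q, D]; K, V: [S, H_kv, D] (same layout as the snippet above)
    H_q, H_kv = q.shape[1], K.shape[1]
    if H_q != H_kv:
        K = K.repeat_interleave(H_q // H_kv, dim=1)
        V = V.repeat_interleave(H_q // H_kv, dim=1)
    out = F.scaled_dot_product_attention(
        q.transpose(0, 1).unsqueeze(0), K.transpose(0, 1).unsqueeze(0),
        V.transpose(0, 1).unsqueeze(0), is_causal=is_causal,
    )
    return out.squeeze(0).transpose(0, 1)  # [T, H_q, D]

def naive_gqa(q, K, V, causal):
    # Reference: explicit softmax(q K^T / sqrt(D)) V per query head
    T, H_q, D = q.shape
    S, H_kv, _ = K.shape
    group = H_q // H_kv
    out = torch.empty_like(q)
    for h in range(H_q):
        kv = h // group  # query head h shares KV head h // group
        scores = q[:, h] @ K[:, kv].T / math.sqrt(D)  # [T, S]
        if causal:
            # square case only (T == S): query i sees keys 0..i
            mask = torch.tril(torch.ones(T, S)).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        out[:, h] = scores.softmax(-1) @ V[:, kv]
    return out
```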
1.5 Q&A from session
student question
Why is `is_prefill = (T == kv_valid_len)` a valid prefill test?
It holds by caller convention. Two cases:
| step | T (queries this call) | kv_valid_len (total cached) | is_prefill |
|---|---|---|---|
| prefill of N-token prompt | N | N (we just wrote N) | true |
| decode step i | 1 | N + i | false |
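So `T == kv_valid_len` is shorthand for "this call's queries cover the entire history". When it's true, SDPA's built-in causal mask is the right mask. When it's false, the new token attends to all history (no mask). That is only correct because decode feeds exactly one query. A multi-token chunk appended to an existing cache would need an explicit mask: when query and key lengths differ, SDPA's `is_causal=True` aligns its triangle to the top-left corner, not the bottom-right.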
student push-back
"This only works because we pass kv_valid_len = T during prefill. It's a coupling smell."
Correct instinct. The cleaner formulation is per-sequence: q_len_i == kv_len_i. We'll switch to that representation in L3 once cu_seqlens/seq_lens exist. For L1's contiguous, single-sequence case, the boolean is fine and it removes one parameter from the forward signature.
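One way to get the per-sequence formulation without special-casing prefill vs. decode is an explicit mask aligned to the bottom-right corner. Query `i` of a `q_len`-row chunk sits at absolute position `kv_len - q_len + i`. In this sketch, `bottom_right_causal_mask` and `attend` are hypothetical names, not necessarily what L3 uses. With `q_len == kv_len` the mask reduces to the usual lower triangle, and with `q_len == 1` it allows everything. It also covers chunks appended to an existing cache.

```python
import torch
import torch.nn.functional as F

def bottom_right_causal_mask(q_len: int, kv_len: int, device=None) -> torch.Tensor:
    # Boolean [q_len, kv_len]; True means "may attend" (SDPA's bool-mask convention)
    q_pos = torch.arange(kv_len - q_len, kv_len, device=device).unsqueeze(1)  # [q_len, 1]
    k_pos = torch.arange(kv_len, device=device).unsqueeze(0)                  # [1, kv_len]
    return k_pos <= q_pos

def attend(q, k, v):
    # q: [B, H, q_len, D]; k, v: [B, H, kv_len, D]; queries are the newest positions
    mask = bottom_right_causal_mask(q.shape[-2], k.shape[-2], device=q.device)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```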
student question
How does the KV cache actually help? What computation did we save?
At decode step i with history length L = N + i, without the cache you'd rerun the forward pass over all L positions. Under causal attention, a position's hidden state at each layer (and therefore its K and V) never changes once computed. Recomputing them is pure waste, but without storing them you have no choice. With the cache:
- Q/K/V projections run on only 1 token (the new one), not L.
- Attention reads K/V for L tokens but multiplies them against only 1 query row.
- The MLP runs on 1 token instead of L.
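Here is a toy check that caching loses nothing. It makes simplifying assumptions: one layer, one head, random weights, no RoPE or norms, and hypothetical names throughout. The new token's attention output computed from cached K/V matches the last row of a full recompute over all positions.

```python
import torch

torch.manual_seed(0)
d = 16
Wq, Wk, Wv = (torch.randn(d, d) / d**0.5 for _ in range(3))

def causal_attend(q, K, V):
    # q: [T, d] for the LAST T positions; K, V: [S, d] for all S positions
    T, S = q.shape[0], K.shape[0]
    scores = q @ K.T / d**0.5
    q_pos = torch.arange(S - T, S).unsqueeze(1)
    scores = scores.masked_fill(torch.arange(S).unsqueeze(0) > q_pos, float("-inf"))
    return scores.softmax(-1) @ V

def no_cache_step(x_all):
    # recompute Q, K, V for every position and keep only the newest output row
    return causal_attend(x_all @ Wq, x_all @ Wk, x_all @ Wv)[-1]

x = torch.randn(10, d)                      # layer inputs for 10 positions
K_cache, V_cache = x[:9] @ Wk, x[:9] @ Wv   # filled by earlier steps
x_new = x[9:]                               # the one new token
K_cache = torch.cat([K_cache, x_new @ Wk])  # append, don't recompute
V_cache = torch.cat([V_cache, x_new @ Wv])
cached = causal_attend(x_new @ Wq, K_cache, V_cache)[0]
```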
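FLOPs/token/layer for Qwen3-8B (4096 hidden, 12288 MLP, 32 Q heads, 8 KV heads, 128 head_dim), counting 2 FLOPs per multiply-accumulate and ignoring norms, RoPE and softmax:
Per-token compute (decode with cache, L=1024): ~386 MFLOPs (projections + MLP) + ~17 MFLOPs (attention) ≈ 403 MFLOPs / token / layer
Per-step recompute (no cache, L=1024): ~386 MFLOPs * 1024 + ~8.6 GFLOPs (causal attention) ≈ 404 GFLOPs / step / layer
Speedup at L=1024: ~1000× (roughly: linear in context length)
With the cache, each decode step's attention cost drops from O(L²) to O(L), and its projection/MLP cost drops from O(L) to O(1). That is the whole reason fast autoregressive inference is feasible at all.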
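Here is a small calculator for the per-layer arithmetic. The sketch counts 2 FLOPs per multiply-accumulate and includes GQA's narrower K/V projections. It ignores norms, RoPE, softmax and the `lm_head`. The function names are hypothetical.

```python
def linear_flops(hidden, inter, n_heads, n_kv_heads, head_dim):
    # q/k/v/o projections + SwiGLU MLP for one token, 2 FLOPs per multiply-accumulate
    q = hidden * n_heads * head_dim
    kv = 2 * hidden * n_kv_heads * head_dim  # GQA: K and V are narrower than Q
    o = n_heads * head_dim * hidden
    mlp = 3 * hidden * inter                 # gate, up, down
    return 2 * (q + kv + o + mlp)

def attn_flops(n_ctx, n_heads, head_dim):
    # one query row against n_ctx keys: QK^T and PV, per head
    return 2 * 2 * n_heads * n_ctx * head_dim

cfg = dict(hidden=4096, inter=12288, n_heads=32, n_kv_heads=8, head_dim=128)
L = 1024
with_cache = linear_flops(**cfg) + attn_flops(L, cfg["n_heads"], cfg["head_dim"])
# no cache: rerun all L positions; position l attends to l keys (causal)
no_cache = L * linear_flops(**cfg) + sum(
    attn_flops(l, cfg["n_heads"], cfg["head_dim"]) for l in range(1, L + 1))
print(f"{with_cache / 1e6:.0f} MFLOPs/token vs {no_cache / 1e9:.0f} GFLOPs/step, "
      f"ratio {no_cache / with_cache:.0f}x")  # 403 MFLOPs/token vs 404 GFLOPs/step, ratio 1003x
```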
1.6 The smoke harness
Loading both your model and HF simultaneously OOMs on a 24-32 GiB GPU (Qwen3-8B bf16 ≈ 16 GiB each). Load sequentially:
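```python
import gc

import torch
from transformers import AutoModelForCausalLM

# MODEL_DIR, N_NEW, load_model and our_greedy come from the rest of the harness.
# prompt_ids is assumed to be a 1-D tensor of prompt token IDs on the GPU.

def run_ours(prompt_ids):
    model = load_model(MODEL_DIR)
    model.eval()
    out = our_greedy(model, prompt_ids, N_NEW).cpu()
    del model  # drop the only reference so the weights can be freed
    gc.collect(); torch.cuda.empty_cache(); torch.cuda.synchronize()
    return out

def run_hf(prompt_ids):
    hf = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16).cuda().eval()
    hf_out = hf.generate(prompt_ids[None], do_sample=False, max_new_tokens=N_NEW)
    out = hf_out[0, prompt_ids.numel():].cpu()  # strip the echoed prompt
    del hf, hf_out
    gc.collect(); torch.cuda.empty_cache(); torch.cuda.synchronize()
    return out
```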
1.7 Acceptance
L1 pass criteria: `our_ids == hf_ids`, identical for 20 greedy tokens on the prompt `"The capital of France is"`. Expected: `[12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625]` → "Paris. The capital of France is Paris. The capital of France is Paris. The capital of France".
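Here is a small helper for the acceptance check. It assumes both runs are converted to plain lists of token IDs (e.g. via `.tolist()` on the tensors returned in 1.6). `first_mismatch` and `EXPECTED` are hypothetical names. Reporting the first differing index, rather than just pass/fail, maps directly onto debugging symptoms: "token 0 OK, all later tokens wrong" shows up as a mismatch at index 1.

```python
def first_mismatch(ours, ref):
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (a, b) in enumerate(zip(ours, ref)):
        if a != b:
            return i
    if len(ours) != len(ref):
        return min(len(ours), len(ref))  # one sequence is a strict prefix of the other
    return None

EXPECTED = [12095, 13, 576, 6722, 315, 9625, 374, 12095, 13, 576,
            6722, 315, 9625, 374, 12095, 13, 576, 6722, 315, 9625]

def report(our_ids, hf_ids):
    i = first_mismatch(our_ids, hf_ids)
    return "L1 PASS" if i is None else f"L1 FAIL: first mismatch at token {i}"
```

In the harness this would be called as something like `print(report(run_ours(ids).tolist(), run_hf(ids).tolist()))`.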
1.8 Debug log (real bugs we hit)
- RoPE precompute parens. See 1.4 above. Symptom: token 0 OK, all later tokens wrong. Lesson: when porting a formula from a paper into code, parenthesize the exponent.
- RMSNorm cast order. Computing `weight.to(fp32) * x32` and then casting back is not bit-identical to `weight * x32.to(in_dtype)` (one bf16 rounding vs. two), so it breaks HF parity. Match HF's order.
- VRAM blowup. Holding our model and HF's at the same time OOMs; load them sequentially (see 1.6).