
Commit 54da780

Natfii and claude committed
refactor(cute): C1.5 — delete Phase 4 + F.1 layer-LN bake plumbing
Per audit Finding 1 and the Q4 self-review, the F.1 layer-LN bake machinery couldn't survive Qwen3.5's stride-4 layer pattern: Phase 4 added mlp_out into residual_output in place, and the next layer (linear-attn, every 4th layer) doesn't honor the F.1 skip-op — so its input_layernorm re-applied LN over the pre-baked output, corrupting the residual stream.

Resolution: per-layer input_layernorm at every decoder layer entry, matching the unfused flow and every surveyed hybrid model (Jamba, Zamba2, Qwen3-Next, Megatron hybrid). β-coop's output is now (mlp_output, residual_output=residual_post_attn); layer N+1's input_layernorm in Python does the residual+mlp accumulation.

Deletions:
- cute_phase_e_skip_input_layernorm op (_mlp_op.py)
- attach_input_layernorm + attach_next_input_layernorm methods commented out (kept commented per feedback_comment_not_delete; C4 fully removes)
- _phase_e_skip_next_ln, _input_layernorm_module field inits
- Phase 4 ε epilogue from the run_beta_coop_full body and from the _kernel_phase_0_to_4 JIT (~150 lines removed)
- run_beta_coop_full's next_input_layernorm_gamma, next_hidden_output, emit_next_layernorm parameters
- attach loops in Qwen3_5Model.__init__
- skip-op call site in Qwen3_5DecoderLayer.forward — replaced with an unconditional self.input_layernorm(hidden_states, residual)

Cascade fixes (authorized in implementer dispatch):
- next_hidden_scratch allocation moved from attach_next_input_layernorm to __init__ — β-lite (kept through C3) still references it
- _phase_e_attached gate at _backend.py:1147 rewired from hasattr(_next_input_layernorm_module) to (_phase_e_coop_kernel is not None or _mlp_fusion_bound)
- cute_phase_e_dispatch consume branch reads impl.mlp_output[:nat] (was impl.next_hidden_scratch[:nat])
- _next_input_layernorm_module + _emit_next_layernorm field inits KEPT as defensive defaults (β-lite reads them via getattr-with-default)

Out of scope (kept untouched):
- β-lite launch site at _backend.py:1278+ (deleted in C3 with the rest of β-lite)
- Standalone Phase 4 launcher (run_phase_4_only, _jit_launch_phase_4_only, _kernel_phase_4_only) at phase_e_kernel.py:2412-2683 — test-only / β-lite-style infra
- paged_attention_forward in kernel.py (C2 retires it from decode)

L3 multi-layer test added at tests/v1/cute_paged/test_uber_kernel_multi_layer.py with 5 source-text assertions covering the deletions and the unconditional input_layernorm regime. Pytest: 7/7 PASS (2 C1 + 5 C1.5).

Validation:
- Live serve probe with CUTE_PHASE_E_FUSION=1: coherent reasoning output; "The capital of France is" → " Paris, and Paris is located in France, so Paris is" — the math fix holds.
- gsm8k_eval_50 ≥90% gate DEFERRED to C2: throughput is still collapsed at ~0.7 tok/s by paged_attention_forward plus the β-coop double-fire of Phases A+B+C. C2 retires paged_attention_forward from decode and recovers throughput; the gsm8k gate runs there.

Diff: 4 modified + 1 new file, -217 net lines.

Refs:
- docs/superpowers/specs/2026-04-25-uber-kernel-migration-design.md
- docs/research/uber_kernel_migration/spec_audit_2026-04-25.md (Finding 1)
- docs/research/uber_kernel_migration/q4_brainstorm_layer_LN_2026-04-25.md
- memory:feedback_layer_output_contract
- memory:feedback_comment_not_delete

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a65bcef commit 54da780
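For reference, a minimal sketch of the post-C1.5 layer boundary described in the message above. This is illustrative only, not code from the diff: it assumes the usual fused-add norm contract (norm(x, residual) returns the normed tensor plus the accumulated residual), and the names RMSNormWithResidual and enter_next_layer are placeholders, not identifiers from the repository.

import torch
from torch import nn


class RMSNormWithResidual(nn.Module):
    """Illustrative stand-in for a decoder layer's input_layernorm.

    Assumed fused-add contract: forward(x, residual) first accumulates
    residual + x, then RMS-normalizes, and returns (normed, new_residual).
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, residual=None):
        if residual is not None:
            # This is the residual + mlp_out sum that Phase 4 used to do
            # in place inside the previous layer's kernel epilogue.
            x = x + residual
            residual = x
        var = x.float().pow(2).mean(-1, keepdim=True)
        normed = (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight
        return normed if residual is None else (normed, residual)


# Post-C1.5 boundary: layer N's beta-coop kernel returns
# (mlp_output, residual_post_attn); layer N+1 then runs its own
# input_layernorm unconditionally at entry (no skip-op, no Phase 4 bake).
def enter_next_layer(mlp_output, residual_post_attn, input_layernorm):
    hidden_states, residual = input_layernorm(mlp_output, residual_post_attn)
    return hidden_states, residual

Because the accumulation now happens at every layer entry, the contract no longer depends on whether the following layer is full attention or linear attention, which is the failure mode Finding 1 describes.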

5 files changed

Lines changed: 354 additions & 457 deletions

tests/v1/cute_paged/test_uber_kernel_multi_layer.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+"""L3 multi-layer test: verifies layer-boundary semantics post-C1.5.
+
+Catches:
+- Phase 4 not adding mlp_out (audit Finding 1) — layer N+1's input_LN does the sum
+- Per-layer input_layernorm fires unconditionally (no skip-op fall-through)
+- F.1 layer-LN bake plumbing (skip-op, attach methods, flags) is gone
+- run_beta_coop_full no longer takes Phase 4 / next-LN parameters
+- cute_phase_e_dispatch consume branch reads mlp_output, not next_hidden_scratch
+
+Strategy: pure source-text inspection via `inspect.getsource`. The full
+kernel-level diff is covered by L4 (gsm8k); this test catches the
+structural class. No CUDA, no kernel launch — runs anywhere.
+"""
+import inspect
+
+
+def test_qwen35_layer_forward_runs_input_layernorm_unconditionally():
+    """qwen3_5.py: input_LN gate must collapse to unconditional run.
+
+    Post-C1.5 the non-first-layer branch of Qwen3_5DecoderLayer.forward
+    must call self.input_layernorm(...) directly — no skip-op detour
+    via cute_phase_e_skip_input_layernorm.
+    """
+    from vllm.nvllm.models import qwen3_5
+    src = inspect.getsource(qwen3_5.Qwen3_5DecoderLayer.forward)
+    assert "cute_phase_e_skip_input_layernorm" not in src, (
+        "F.1 skip-op call site still present in layer forward. "
+        "Should be deleted in C1.5."
+    )
+    assert "self.input_layernorm(hidden_states, residual)" in src, (
+        "Expected unconditional self.input_layernorm(hidden_states, residual) "
+        "call in non-first-layer branch."
+    )
+
+
+def test_no_attach_input_layernorm_loops_in_model_init():
+    """qwen3_5.py source must drop attach_*_layernorm loops.
+
+    Both attach_input_layernorm and attach_next_input_layernorm loops
+    are gone — the F.1 cross-layer bake plumbing they enabled is gone.
+
+    We grep the file directly because Qwen3_5Model.__init__ is replaced
+    by @support_torch_compile, so inspect.getsource(Qwen3_5Model.__init__)
+    returns the wrapper, not the class body.
+    """
+    from vllm.nvllm.models import qwen3_5
+    with open(qwen3_5.__file__, "r") as f:
+        src = f.read()
+    assert "attach_input_layernorm" not in src, (
+        "attach_input_layernorm reference still present in qwen3_5.py. "
+        "C1.5 must delete the attach loop and any Phase F.1 plumbing. "
+        "(check for impl.attach_input_layernorm(...) call in Qwen3_5Model.__init__)"
+    )
+    assert "attach_next_input_layernorm" not in src, (
+        "attach_next_input_layernorm reference still present in qwen3_5.py. "
+        "C1.5 must delete the attach loop and any Phase F.1 plumbing. "
+        "(check for impl.attach_next_input_layernorm(...) call in Qwen3_5Model.__init__)"
+    )
+
+
+def test_skip_op_deleted():
+    """cute_phase_e_skip_input_layernorm op must be deleted entirely.
+
+    Both the impl/fake functions and the direct_register_custom_op
+    registration must be gone from _mlp_op.py.
+    """
+    from vllm.v1.attention.backends.cute_paged import _mlp_op
+    src = inspect.getsource(_mlp_op)
+    assert 'op_name="cute_phase_e_skip_input_layernorm"' not in src, (
+        "cute_phase_e_skip_input_layernorm op still registered. "
+        "C1.5 must delete the op registration and the impl/fake functions."
+    )
+
+
+def test_phase_4_deleted_from_run_beta_coop_full():
+    """Phase 4 args must be dropped from run_beta_coop_full's signature.
+
+    The kernel returns at the end of Phase 3 (MLP write). The next-layer
+    input_LN runs from Python at every layer entry instead of being baked
+    into the previous layer's epilogue.
+    """
+    from vllm.v1.attention.backends.cute_paged import phase_e_kernel
+    src = inspect.getsource(
+        phase_e_kernel.PhaseE_Beta_Kernel.run_beta_coop_full
+    )
+    assert "next_input_layernorm_gamma" not in src, (
+        "Phase 4 arg next_input_layernorm_gamma still present in "
+        "run_beta_coop_full. C1.5 must drop it."
+    )
+    assert "emit_next_layernorm" not in src, (
+        "Phase 4 arg emit_next_layernorm still present in "
+        "run_beta_coop_full. C1.5 must drop it."
+    )
+
+
+def test_dispatch_op_consumes_mlp_output_not_next_hidden_scratch():
+    """cute_phase_e_dispatch consume branch must read mlp_output.
+
+    Pre-C1.5 the consume branch read impl.next_hidden_scratch (the
+    Phase-4-baked next-layer input_LN output). Post-C1.5 it reads
+    impl.mlp_output (raw post-MLP hidden) and the next layer's
+    input_LN runs from Python.
+    """
+    from vllm.v1.attention.backends.cute_paged import _mlp_op
+    src = inspect.getsource(_mlp_op)
+    assert "next_hidden_scratch" not in src, (
+        "cute_phase_e_dispatch still references next_hidden_scratch. "
+        "C1.5 must update consume branch to read from mlp_output."
+    )
+    assert "impl.mlp_output[:nat]" in src, (
+        "Expected cute_phase_e_dispatch consume branch to read "
+        "impl.mlp_output[:nat] for hidden_out. C1.5 must keep this read "
+        "active — see _mlp_op.py consume branch."
+    )

vllm/nvllm/models/qwen3_5.py

Lines changed: 14 additions & 72 deletions
@@ -413,32 +413,16 @@ def _ct_mark(label: str) -> None:
             _ct_t = _now
 
         if residual is None:
-            # First-layer case: no residual to add. Phase F.1 skip-op only
-            # applies when there's a residual + we're past layer 0.
+            # First-layer case: no residual to add.
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
             _ct_mark("input_ln_first")
         else:
-            # Phase F.1: use opaque skip op if MLP fusion is attached on
-            # THIS layer (attach-time constant, trace-safe). Op body reads
-            # impl._phase_e_skip_next_ln at runtime → passes through when
-            # the previous layer's β ε epilogue already ran input_layernorm.
-            _mlp_layer_name = getattr(self.mlp, "_cute_layer_name", None)
-            if _mlp_layer_name is not None:
-                out_x = torch.empty_like(hidden_states)
-                out_residual = torch.empty_like(residual)
-                _ct_mark("ln_skip_alloc")
-                torch.ops.vllm.cute_phase_e_skip_input_layernorm(
-                    hidden_states, residual, out_x, out_residual,
-                    _mlp_layer_name,
-                )
-                hidden_states, residual = out_x, out_residual
-                _ct_mark("ln_skip_op")
-            else:
-                hidden_states, residual = self.input_layernorm(
-                    hidden_states, residual
-                )
-                _ct_mark("input_ln")
+            # C1.5: Phase F.1 skip-op deleted. The previous layer's β-coop
+            # kernel ends at Phase 3 (no input_LN bake), so every layer
+            # entry runs input_layernorm unconditionally.
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            _ct_mark("input_ln")
 
         # Impl decides fusion per-forward. We mirror residual into impl's
         # persistent buffer unconditionally when fusion could run (full
@@ -624,56 +608,14 @@ def get_layer(prefix: str):
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
         )
 
-        # Phase F.1 cross-layer binding (always-on when MLP fusion is
-        # attached). Cheap module-ref attach; cute_phase_e_skip_input_layernorm
-        # opaque op needs this present even when β kernels are disabled,
-        # because the op call site fires whenever _cute_layer_name is set
-        # (gated by CUTE_MLP_FUSION, not CUTE_PHASE_E_FUSION). Without
-        # this attach, the op's non-skip branch raises fail-loud.
-        import os
-        layer_types = config.layer_types
-        num_layers = config.num_hidden_layers
-        for idx, layer in enumerate(self.layers):
-            if idx < self.start_layer or idx >= self.end_layer:
-                continue
-            if layer_types[idx] != "full_attention":
-                continue
-            attn = getattr(layer.self_attn, 'attn', None)
-            impl = getattr(attn, 'impl', None)
-            if impl is None or not hasattr(impl, 'attach_input_layernorm'):
-                continue
-            impl.attach_input_layernorm(
-                getattr(layer, 'input_layernorm', None)
-            )
-
-        # Phase E cross-layer binding (gated): every fusion-active
-        # (full_attention) decoder layer receives a ref to the NEXT decoder
-        # layer's input_layernorm module. Last layer (idx 63) passes None
-        # so the β kernel's ε epilogue omits the next-layer norm pull.
-        # ALSO allocates β kernel scratch buffers (heavy), so this stays
-        # gated by CUTE_PHASE_E_FUSION.
-        # Spec: docs/superpowers/specs/2026-04-22-unreal-kernel-phase-e-d25-design.md §5.3
-        if os.environ.get("CUTE_PHASE_E_FUSION", "0") == "1":
-            for idx, layer in enumerate(self.layers):
-                if idx < self.start_layer or idx >= self.end_layer:
-                    continue
-                if layer_types[idx] != "full_attention":
-                    continue
-                # impl lives on the inner Attention module, not on the
-                # Qwen3_5Attention wrapper: Qwen3_5Attention.attn is
-                # Attention, Attention.impl is CutePagedAttentionImpl.
-                # Existing pattern: see self_attn.attn.impl at L243, 361, 395.
-                attn = getattr(layer.self_attn, 'attn', None)
-                impl = getattr(attn, 'impl', None)
-                if impl is None or not hasattr(impl, 'attach_next_input_layernorm'):
-                    continue  # non-CuTe backend
-                # getattr tolerates PPMissingLayer (no input_layernorm attr)
-                next_norm = (
-                    getattr(self.layers[idx + 1], 'input_layernorm', None)
-                    if idx + 1 < num_layers
-                    else None
-                )
-                impl.attach_next_input_layernorm(next_norm)
+        # C1.5: Phase F.1 cross-layer binding loops (per-layer + next-layer
+        # LN bake) deleted. The skip-op they enabled (cute_phase_e_skip_*)
+        # was permanently retired in C1.5 along with β-coop's Phase 4
+        # epilogue — every layer now runs input_layernorm unconditionally
+        # at layer entry from Python (see Qwen3_5DecoderLayer.forward).
+        # The corresponding attach_* methods on CutePagedAttentionImpl
+        # are commented-out (not deleted) in _backend.py per the
+        # comment-out-kernel-code rule.
 
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
