
Commit a65bcef

Natfii and claude committed
fix(cute): C1 — β-coop and β-lite read residual_buf, not residual_output
β-coop's Phase 1C residual_in pointed at self.residual_output, which
paged_attention_forward had already filled with (h+r) + wo_out =
residual_post_attn. β-coop then re-added wo_out inside its own Phase 1C,
producing 2·wo_out + h + r — gibberish output cascading through 16 fused
full-attn layers, observed as " 2 ".

The same alias existed in β-lite's residual_post_ln source (audit Finding 6;
β-lite never re-ran Phase C, so the corruption only manifested when β-coop
fired, but β-lite was structurally on the same buggy path).

Fixed both call sites:
- vllm/v1/attention/backends/cute_paged/_backend.py:1175 (β-coop)
- vllm/v1/attention/backends/cute_paged/_backend.py:1268 (β-lite)

Both now read self.residual_buf — the post-input-LN residual mirrored from
qwen3_5.py:460 — matching the math the kernels expect.

An L2 buffer-contracts test was added at
tests/v1/cute_paged/test_uber_kernel_buffer_contracts.py. It uses pure
source-text inspection via inspect.getsource on
CutePagedAttentionImpl.forward, catching the alias structurally without
requiring a GPU run.

Validation:
- Pre-fix pytest: 2 FAILED (the test caught the bug)
- Post-fix pytest: 2 PASSED
- A live serve probe with CUTE_PHASE_E_FUSION=1 produced coherent reasoning
  output (not the pre-fix " 2 ..." gibberish).

The gsm8k_eval_50 ≥90% gate is DEFERRED to C2. At this commit's state,
β-coop and paged_attention_forward both fire Phase A+B+C, costing ~+15 ms
per fused-full-attn layer × 16 layers ≈ 0.7 tok/s observed (predicted by
memory:project_phase_e_phantom_speedup). The 180 s per-question timeout in
scripts/gsm8k_eval_50.py can't accommodate that slowdown. C2 retires
paged_attention_forward from the decode path and recovers throughput; the
gsm8k gate runs there.

Refs:
- docs/superpowers/specs/2026-04-25-uber-kernel-migration-design.md
- docs/research/uber_kernel_migration/spec_audit_2026-04-25.md (Finding 6)
- memory:project_phase_e_beta_math_bug
- memory:project_phase_e_phantom_speedup

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
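A minimal sketch of the aliasing arithmetic described above (illustrative
tensors only: h, r, and wo_out stand for the hidden state, residual, and
attention output projection named in the message, not code from the tree):

import torch

h, r, wo_out = torch.randn(4), torch.randn(4), torch.randn(4)

# paged_attention_forward leaves residual_post_attn in residual_output:
residual_output = (h + r) + wo_out
# qwen3_5.py:460 mirrors the post-input-LN residual into residual_buf:
residual_buf = h + r

# β-coop's Phase 1C re-adds wo_out to whatever residual_in it is handed.
buggy = residual_output + wo_out   # 2*wo_out + h + r  (pre-fix gibberish)
fixed = residual_buf + wo_out      # wo_out + h + r    (post-fix, correct)

assert torch.allclose(fixed, h + r + wo_out)
assert not torch.allclose(buggy, fixed)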
1 parent 2b21f34 commit a65bcef

2 files changed

Lines changed: 50 additions & 2 deletions

tests/v1/cute_paged/test_uber_kernel_buffer_contracts.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
"""L2 structural test: verifies β-coop / β-lite read inputs from the right buffer.

Pre-fix: β-coop reads `self.residual_output` (post-Phase-C output of the legacy
paged_attention_forward), causing residual_post_attn = 2*attn_out + h + r.
Post-fix: β-coop reads `self.residual_buf` (post-input-LN residual mirrored
from qwen3_5.py:460), giving residual_post_attn = attn_out + h + r.

Strategy: pure source-text inspection via `inspect.getsource` on
`CutePagedAttentionImpl.forward`. We assert the post-fix wiring is present
(`self.residual_buf`) and the buggy alias (`self.residual_output` as residual
input to β kernels) is absent. No CUDA, no kernel launch — runs anywhere.
"""
import inspect


def test_beta_coop_residual_in_sources_from_residual_buf():
    """β-coop's residual_in must source from self.residual_buf, not residual_output."""
    from vllm.v1.attention.backends.cute_paged._backend import (
        CutePagedAttentionImpl,
    )

    src = inspect.getsource(CutePagedAttentionImpl.forward)
    assert "residual_in=self.residual_buf" in src, (
        "Expected β-coop launch to read from self.residual_buf; found a different source. "
        "Check _backend.py:1175 — buffer-aliasing bug may have regressed."
    )
    # Strengthened guard (audit Finding 6 / option b): C1 fixes both occurrences
    # of the alias bug, so `residual_in=self.residual_output` must not appear
    # ANYWHERE in CutePagedAttentionImpl.forward source. The original anchor
    # ("# β-coop") doesn't exist in source, so the guarded form silently passed
    # either way. This bare check fails loudly if the bug regresses.
    assert "residual_in=self.residual_output" not in src, (
        "β-coop call site still reads self.residual_output — the alias bug is back. "
        "See _backend.py:1175 (commit 76b88ba21) and audit Finding 6."
    )


def test_beta_lite_residual_post_ln_sources_from_residual_buf():
    """β-lite has the same alias bug pre-migration (audit Finding 6). Verify fix."""
    from vllm.v1.attention.backends.cute_paged._backend import (
        CutePagedAttentionImpl,
    )

    src = inspect.getsource(CutePagedAttentionImpl.forward)
    assert "residual_post_ln=self.residual_buf" in src, (
        "β-lite still aliases legacy buffer. See audit Finding 6 / _backend.py:1268."
    )
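The tests need no GPU or kernel build; a typical invocation is plain pytest
(path as added by this commit):

pytest tests/v1/cute_paged/test_uber_kernel_buffer_contracts.py -q

Source-text inspection is deliberately coarse: it pins the call-site wiring
rather than the numerics, which keeps the guard runnable in CPU-only CI.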

vllm/v1/attention/backends/cute_paged/_backend.py

Lines changed: 2 additions & 2 deletions
@@ -1172,7 +1172,7 @@ def forward(
             # Phase 0 inputs (dummy — output side-channel for future
             # QKV-fusion; not consumed by this layer's attn path).
             hidden_in=self.rmsnorm_output[:nat],
-            residual_in=self.residual_output[:nat],
+            residual_in=self.residual_buf[:nat],
             input_gamma=self._phase_e_coop_input_gamma,
             post_attn_gamma=self.rmsnorm_gamma,
             attn_input_bf16=self._phase_e_coop_attn_input_scratch[:nat],
@@ -1265,7 +1265,7 @@ def forward(
             gate_up_global_scale=self._mlp_gate_up_gs,
             down_global_scale=self._mlp_down_gs,
             # ε epilogue inputs (Task 8 kwargs):
-            residual_post_ln=self.residual_output[:nat],
+            residual_post_ln=self.residual_buf[:nat],
             next_input_layernorm_gamma=_next_gamma,
             next_hidden_output=self.next_hidden_scratch[:nat],
             emit_epilogue=True,
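Beyond the structural test, a runtime guard could also catch a regression of
this alias at launch time. A hypothetical sketch (not part of this commit;
`untyped_storage().data_ptr()` is standard torch, while the helper name and
its placement are invented for illustration):

import torch

def assert_no_residual_alias(residual_in: torch.Tensor,
                             residual_output: torch.Tensor) -> None:
    """Fail fast if the β residual input aliases the legacy Phase-C buffer.

    The call sites pass slices like self.residual_buf[:nat], and slices share
    storage with their base tensor, so compare storage pointers rather than
    tensor identity.
    """
    if (residual_in.untyped_storage().data_ptr()
            == residual_output.untyped_storage().data_ptr()):
        raise RuntimeError(
            "residual_in aliases residual_output (audit Finding 6; "
            "see _backend.py:1175 and :1268)"
        )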
