Commit 3ffcf87

Revert "wip(cute): B-fix attempt — consume-gate DCE + post-attn-LN dispatch ops"
This reverts commit 514b88c.
1 parent 514b88c commit 3ffcf87

3 files changed: 14 additions & 294 deletions

vllm/nvllm/models/qwen3_5.py (14 additions & 69 deletions)
@@ -463,50 +463,17 @@ def _ct_mark(label: str) -> None:
                 positions=positions,
             )
             _ct_mark("self_attn")
-            # 2026-04-26 (B-fix): the prior `if getattr(impl, "_fusion_active",
-            # False)` Python-bool gate was dead-eliminated by torch.compile —
-            # at trace time `_fusion_active` was False (impl __init__ default),
-            # so dynamo specialised the if-branch as dead and the captured
-            # graph always ran the else fall-through. Empirically verified
-            # via /root/.cache/vllm/torch_compile_cache/<hash>/.../
-            # computation_graph.py: the consume `.copy_()` calls were absent
-            # AND the legacy Python o_proj path was always present.
-            #
-            # Replace with an opaque op (cute_attn_consume) that always runs
-            # in the captured graph and dispatches at runtime via
-            # `impl._fusion_active_signal` (a 0-dim int32 tensor mutated
-            # inside the unified_attention opaque op, where dynamo can't
-            # see the change). When the signal == 0, the op no-ops; when
-            # it's > 0 (β-coop fired), it copies β-coop's outputs into
-            # self_attention_output and residual.
-            #
-            # residual_buf and gate_buf are passed as PHANTOM inputs: they
-            # are not used inside the op body, but their presence forces
-            # a data dependency on cute_residual_mirror's output, which
-            # otherwise gets DCE'd despite mutates_args (verified: only
-            # mutates_args is NOT enough to survive DCE if no graph op
-            # reads the mutated tensor).
-            if impl is not None and getattr(impl, "_fusion_bound", False):
-                # Consistent with the cute_residual_mirror gate above
-                # (_fusion_bound is set in attach_fusion, stable at trace
-                # time — dynamo's specialization on it is correct because
-                # it's a one-time setup flag, not a per-step runtime flag).
-                # The op uses _CUTE_ATTN_REGISTRY[layer_name] internally to
-                # read impl._phase_e_use_beta_coop at runtime — Python attr
-                # access only, no .item() / no CUDA sync, safe under graph
-                # capture (verified failure mode 2026-04-26 from .item():
-                # cudaErrorStreamCaptureInvalidated).
-                torch.ops.vllm.cute_attn_consume(
-                    self_attention_output,
-                    residual,
-                    impl.rmsnorm_output,
-                    impl.residual_output,
-                    impl.residual_buf,
-                    impl.gate_buf,
-                    self.self_attn.attn.layer_name,
-                )
-            hidden_states = self_attention_output
-            _ct_mark("attn_consume_or_legacy")
+            if impl is not None and getattr(impl, "_fusion_active", False):
+                # Kernel already did gate*attn, W_O GEMV, residual+RMSNorm.
+                self_attention_output[:nat].copy_(impl.rmsnorm_output[:nat])
+                if nat < num_tokens:
+                    self_attention_output[nat:].zero_()
+                residual[:nat].copy_(impl.residual_output[:nat])
+                hidden_states = self_attention_output
+                _ct_mark("attn_consume")
+            else:
+                hidden_states = self_attention_output
+                _ct_mark("attn_legacy")
         else:
             raise ValueError("Invalid layer_type")
 
@@ -520,35 +487,13 @@ def _ct_mark(label: str) -> None:
                 self.attn_layer_scale.to(hidden_states.dtype) + 1
             )
 
-        # 2026-04-26 (B-fix): post_attn_LN dispatch via opaque op for full
-        # attention layers (replacing the dead-eliminated Python-bool gate).
-        # The prior `if not getattr(impl, "_fusion_active", False)` was
-        # specialised to always-run by dynamo (because trace-time
-        # `_fusion_active = False`, so `not False = True`). The captured
-        # graph ran post_attn_LN unconditionally — fine in dual-fire (β-coop's
-        # rmsnorm_output was unused anyway, Python pipeline did the work)
-        # but in solo it operated over uninitialised self_attention_output
-        # because β-coop doesn't expose Phase A to the framework `output`
-        # parameter.
-        #
-        # Linear-attention layers have impl=None and no fusion signal, so
-        # they keep the plain Python module call (no compile fragility there
-        # — the dead-elim only bites paths that depend on a runtime-mutated
-        # Python attribute).
-        if impl is not None and getattr(impl, "_fusion_bound", False):
-            torch.ops.vllm.cute_post_attn_ln_dispatch(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                float(self.post_attention_layernorm.variance_epsilon),
-                self.self_attn.attn.layer_name,
-            )
-            _ct_mark("post_attn_ln_dispatch")
-        else:
+        if not getattr(impl, "_fusion_active", False):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
             _ct_mark("post_attn_ln")
+        else:
+            _ct_mark("post_attn_skip")
 
         # Phase E β-lite consume. When the CuTe backend launched the
         # β-lite dispatch inside its forward, the MLP kernel's ε epilogue
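
The comment block removed above describes the underlying torch.compile behaviour: an `if` that tests a plain Python attribute is specialised at trace time, so a mutation performed later inside an opaque custom op never re-enables the branch. Below is a minimal, self-contained sketch of that failure mode. It is illustrative only, not vLLM code: the `Impl` class, the `demo::mark_fusion` op, and the attribute name are invented for the example, and it assumes PyTorch 2.4+ for `torch.library.custom_op`.

import torch
from torch.library import custom_op


class Impl:
    def __init__(self) -> None:
        self.fusion_active = False  # value dynamo sees at trace time


impl = Impl()


@custom_op("demo::mark_fusion", mutates_args=())
def mark_fusion(x: torch.Tensor) -> torch.Tensor:
    impl.fusion_active = True  # runtime side effect, invisible to dynamo
    return x.clone()


@mark_fusion.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


def layer(x: torch.Tensor) -> torch.Tensor:
    y = torch.ops.demo.mark_fusion(x)
    # Traced while impl.fusion_active is False, so the compiled graph keeps
    # only the else path, even though the op above flips the flag earlier in
    # the same call.
    if impl.fusion_active:
        return y * 2.0
    return y + 1.0


compiled = torch.compile(layer)
print(compiled(torch.full((2,), 3.0)))  # tensor([4., 4.]): the "+ 1" path runs

The same mechanism applies to both reverted gates in this file: the consume branch and the post_attn_LN skip key off an attribute that only changes inside the opaque unified_attention op.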

vllm/v1/attention/backends/cute_paged/_backend.py (0 additions & 48 deletions)
@@ -336,20 +336,6 @@ def _preallocate_fusion_buffers(
             max_num_seqs, hidden_dim, dtype=torch.bfloat16, device=device
         )
 
-        # 2026-04-26 (B-fix): runtime signal tensor for the consume-or-postln
-        # dispatch in qwen3_5.py. 0-dim int32 written inside impl.forward
-        # (which is wrapped in the unified_attention opaque op, invisible
-        # to dynamo). Read at runtime by `cute_attn_consume` and
-        # `cute_post_attn_ln_dispatch` ops via .item(). Value: 0 = fusion
-        # didn't fire (run Python o_proj/post_attn_LN normally), N > 0 =
-        # fusion fired with N tokens (use β-coop outputs, skip post_attn_LN).
-        # This replaces the dead-eliminated `getattr(impl, "_fusion_active",
-        # False)` Python-bool gates with a tensor-based signal that survives
-        # torch.compile specialization.
-        self._fusion_active_signal = torch.zeros(
-            (), dtype=torch.int32, device=device
-        )
-
         # Phase D MLP fusion buffers. Shape-defining axes (`slice_ctas`
         # for `mlp_partial_fp32`, `num_k_tiles` for `mlp_arrival_count`)
         # are both kernel-side constants resolved inside
@@ -669,15 +655,6 @@ def _resolve_fusion_weights(self) -> None:
             return
 
         self._fusion_bound = True
-        # 2026-04-26 (B-fix): register self in the attn-consume registry so
-        # cute_attn_consume / cute_post_attn_ln_dispatch can look up the impl
-        # at runtime via layer_name string. Avoids passing impl as a custom-op
-        # arg (not supported) AND avoids reading a 0-dim tensor signal via
-        # .item() (causes cudaErrorStreamCaptureInvalidated under graph capture).
-        from vllm.v1.attention.backends.cute_paged._mlp_op import (
-            _CUTE_ATTN_REGISTRY,
-        )
-        _CUTE_ATTN_REGISTRY[self._fusion_prefix] = self
         logger.info(
             "CuTe fusion resolved: layer=%s wo_weight=%s rmsnorm_gamma=%s",
             self._fusion_prefix,
@@ -1043,23 +1020,6 @@ def forward(
         fits_buffer = num_actual_tokens <= getattr(self, "_fusion_max_num_seqs", 0)
         self._fusion_active = self._fusion_bound and is_decode_only and fits_buffer
         use_fusion = self._fusion_active
-        # 2026-04-26 (B-fix): per-step reset for the consume gate. Both flags
-        # are read inside opaque op bodies (cute_attn_consume and
-        # cute_post_attn_ln_dispatch) at runtime via Python attribute access
-        # — keyed off impl from _CUTE_ATTN_REGISTRY by layer_name. Resetting
-        # here ensures the gate reflects THIS forward call (β-coop may not
-        # fire even when fusion is bound, e.g. predicate fails or kernel
-        # falls back to β-lite/paged via the except handler below).
-        #
-        # _fusion_active_signal stays as a 0-dim tensor for debug visibility
-        # but is NOT read inside the consume ops anymore — switching to
-        # Python attr access avoids the .item() host-device sync that broke
-        # CUDA graph capture (cudaErrorStreamCaptureInvalidated 2026-04-26).
-        # Kept commented-out for the moment so the .fill_() side effect can
-        # be re-enabled if a future debug session wants the visibility back.
-        self._phase_e_use_beta_coop = False
-        # if hasattr(self, "_fusion_active_signal"):
-        #     self._fusion_active_signal.fill_(0)
         # --- PHASE D2 DISABLED (commented, not deleted — Phase B/C debug may
         # need this reset back) ---
         # Pre-D2, the MLP fusion launch was an attention-side side effect
@@ -1350,14 +1310,6 @@ def forward(
                 )
                 self._phase_e_consumed = True
                 self._phase_e_use_beta_coop = True
-                # 2026-04-26 (B-fix): the consume gate now reads
-                # `impl._phase_e_use_beta_coop` (Python attr) inside the
-                # opaque op body via _CUTE_ATTN_REGISTRY lookup — no .item()
-                # call, no host-device sync, CUDA-graph-safe. The tensor
-                # signal `.fill_(nat)` below is kept commented (not deleted)
-                # so it can be re-enabled if a future debug session wants
-                # tensor-side visibility into β-coop firing decisions.
-                # self._fusion_active_signal.fill_(nat)
                 # 2026-04-26: ENV-GATED dump for off-line math verification.
                 # CUTE_DUMP_TENSORS=1 enables; bounded to first 3 decode
                 # steps × 16 full-attn layers so disk doesn't bloat. Files
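
Several of the comments removed above cite the same constraint: reading a 0-dim device tensor with `.item()` performs a device-to-host copy plus a stream sync, which is not allowed while a CUDA stream is being captured, so the capture is invalidated (the cudaErrorStreamCaptureInvalidated the comments mention), whereas reading a plain Python attribute involves no device work. A standalone sketch of the pitfall, assuming a CUDA device is available; this is an illustration, not vLLM code, and the exact exception text varies with the PyTorch and driver versions.

import torch

signal = torch.zeros((), dtype=torch.int32, device="cuda")
graph = torch.cuda.CUDAGraph()

try:
    with torch.cuda.graph(graph):
        # Host-device sync during capture: CUDA rejects it and the capture
        # is invalidated, surfacing as a RuntimeError.
        _ = signal.item()
except RuntimeError as exc:
    print(f"capture failed: {exc}")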

vllm/v1/attention/backends/cute_paged/_mlp_op.py (0 additions & 177 deletions)
@@ -46,17 +46,6 @@
 # The op body reads from this dict at runtime (not at trace time).
 _CUTE_MLP_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {}
 
-# 2026-04-26 (B-fix): attn-consume registry, populated by
-# `CutePagedAttentionImpl.attach_fusion`. Same impl object as
-# _CUTE_MLP_REGISTRY but keyed by ATTENTION layer name (e.g.
-# `language_model.model.layers.3.self_attn.attn`), not the MLP key
-# used by cute_phase_e_dispatch. Allows cute_attn_consume and
-# cute_post_attn_ln_dispatch to look up the impl and read its
-# Python-side flags at runtime — avoids the .item() host-device sync
-# on a 0-dim tensor signal (which raises cudaErrorStreamCaptureInvalidated
-# under CUDA graph capture, verified 2026-04-26).
-_CUTE_ATTN_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {}
-
 
 def _cute_mlp_forward_impl(
     x: torch.Tensor,
@@ -320,169 +309,3 @@ def _cute_residual_mirror_fake(
     mutates_args=["residual_buf"],
     fake_impl=_cute_residual_mirror_fake,
 )
-
-
-# --- 2026-04-26: cute_attn_consume + cute_post_attn_ln_dispatch ----------------
-# B-fix: replace the dead-eliminated Python `if _fusion_active` consume branch
-# at qwen3_5.py:466-476 and the dead-eliminated `if not _fusion_active`
-# post_attention_layernorm gate at qwen3_5.py:490-496.
-#
-# WHY needed: the captured FX graph (verified 2026-04-26 via
-# /root/.cache/vllm/torch_compile_cache/<hash>/rank_0_0/backbone/computation_graph.py)
-# specialized BOTH gates at trace time on `_fusion_active = False` (the impl's
-# __init__ default) — dynamo can't see the runtime mutation that happens inside
-# the unified_attention opaque op. Result: the consume copy was DCE'd, the
-# legacy Python o_proj + post_attn_LN ALWAYS ran, β-coop's rmsnorm_output /
-# residual_output were never read by the captured graph. In dual-fire this
-# happened to produce coherent output because paged populated `output` with
-# Phase A and the Python pipeline applied o_proj + post_attn_LN over it. In
-# solo (paged gated off, β-coop only), `output` stayed uninitialised and
-# Python applied o_proj over junk → gibberish.
-#
-# Fix: route the consume / postln decision through a runtime tensor signal
-# (`impl._fusion_active_signal`, 0-dim int32) that's mutated INSIDE the
-# unified_attention op (invisible to dynamo's specialization) and read at
-# runtime via .item() inside these opaque ops. Both ops always run, dispatch
-# at runtime via the signal value:
-#   signal == 0 : non-fusion mode (β-coop didn't fire). consume no-ops;
-#                 postln applies the fused-residual RMSNorm in-place over
-#                 the Python o_proj's wo_out.
-#   signal > 0  : fusion mode (β-coop fired with N=signal tokens). consume
-#                 copies β-coop's rmsnorm_output → self_attention_output and
-#                 residual_output → residual; postln no-ops (β-coop's Phase
-#                 1C already produced LN(post_input_LN_residual + wo_out)·γ).
-#
-# residual_buf and gate_buf are passed to consume as PHANTOM inputs (not
-# read inside the body) — their sole purpose is to give the cute_residual_mirror
-# and cute_residual_mirror(gate_buf, ...) ops observable downstream readers
-# in the captured graph, which prevents dynamo's DCE from dropping them
-# (verified empirically that mutates_args alone is NOT sufficient against
-# DCE — the ops were dead-eliminated despite mutates_args=["residual_buf"]
-# until a downstream reader was added).
-
-
-def _cute_attn_consume_impl(
-    self_attention_output: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
-    residual: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
-    rmsnorm_output: torch.Tensor,  # impl.rmsnorm_output [max_num_seqs, hidden_dim] BF16
-    residual_output: torch.Tensor,  # impl.residual_output [max_num_seqs, hidden_dim] BF16
-    residual_buf: torch.Tensor,  # phantom for cute_residual_mirror dep
-    gate_buf: torch.Tensor,  # phantom for gate-mirror dep
-    layer_name: str,  # registry key into _CUTE_ATTN_REGISTRY
-) -> None:
-    """If β-coop fired this step: copy its outputs into model-side tensors.
-
-    Reads `impl._phase_e_use_beta_coop` (Python attr) at runtime via
-    `_CUTE_ATTN_REGISTRY[layer_name]` — no .item() call, no CUDA sync,
-    safe under CUDA graph capture. Reset to False at top of impl.forward,
-    set to True only on successful β-coop launch — so True ⇔ β-coop wrote
-    rmsnorm_output and residual_output for THIS forward call.
-    """
-    impl = _CUTE_ATTN_REGISTRY.get(layer_name)
-    # 2026-04-26 (B-fix v2): gate on `_fusion_bound` (set once at
-    # attach_fusion, stable across warmup + runtime) rather than
-    # `_phase_e_use_beta_coop` (set per-step inside impl.forward — not
-    # consistently True at warmup capture time, so the captured segment
-    # would skip the consume kernels and replay would never fill
-    # self_attention_output from β-coop's outputs). With _fusion_bound:
-    # capture always sees True for fusion-bound full-attn layers,
-    # consume kernels always captured. Cost: if β-coop ever fails to
-    # fire at runtime (e.g. predicate fails), consume reads stale
-    # impl.rmsnorm_output. Mitigated by the predicate hard-gate landed
-    # in the prior commit which prevents silent β-coop fallthrough on
-    # cooperative-launch-too-large.
-    if impl is None or not getattr(impl, "_fusion_bound", False):
-        # Non-fusion / non-bound: leave self_attention_output as-is (Python
-        # o_proj already wrote it) and residual untouched.
-        return
-    # Fusion mode: β-coop's Phase 1C produced these. Bound by buffer capacity
-    # defensively (matches the original Python consume branch).
-    nat = min(self_attention_output.shape[0], rmsnorm_output.shape[0])
-    self_attention_output[:nat].copy_(rmsnorm_output[:nat])
-    if nat < self_attention_output.shape[0]:
-        # Match the prior `if nat < num_tokens: self_attention_output[nat:].zero_()`
-        # — keeps unused rows deterministic across decode steps.
-        self_attention_output[nat:].zero_()
-    residual[:nat].copy_(residual_output[:nat])
-
-
-def _cute_attn_consume_fake(
-    self_attention_output: torch.Tensor,
-    residual: torch.Tensor,
-    rmsnorm_output: torch.Tensor,
-    residual_output: torch.Tensor,
-    residual_buf: torch.Tensor,
-    gate_buf: torch.Tensor,
-    layer_name: str,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="cute_attn_consume",
-    op_func=_cute_attn_consume_impl,
-    # Both self_attention_output and residual are mutated when fusion fires;
-    # the phantom inputs are read-only.
-    mutates_args=["self_attention_output", "residual"],
-    fake_impl=_cute_attn_consume_fake,
-)
-
-
-def _cute_post_attn_ln_dispatch_impl(
-    hidden_states: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
-    residual: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
-    weight: torch.Tensor,  # post_attention_layernorm.weight [hidden_dim] BF16
-    rmsnorm_eps: float,
-    layer_name: str,  # registry key into _CUTE_ATTN_REGISTRY
-) -> None:
-    """If β-coop did NOT fire: apply fused-residual post_attention_layernorm.
-
-    Mirrors `_forward_static_with_residual` in vllm/nvllm/layers/layernorm.py:
-        combined = hidden_states + residual
-        residual = combined
-        x = combined.float()
-        var = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(var + eps)
-        x = x * (1.0 + weight.float())
-        hidden_states = x.to(combined.dtype)
-
-    When β-coop fired, its Phase 1C already produced this exact output into
-    hidden_states via cute_attn_consume above, and residual already holds
-    residual_post_attn — skip to avoid double-LN.
-
-    Reads `impl._phase_e_use_beta_coop` (Python attr) — no .item() needed,
-    CUDA-graph-safe. See cute_attn_consume docstring for the gate semantics.
-    """
-    impl = _CUTE_ATTN_REGISTRY.get(layer_name)
-    # See cute_attn_consume docstring above for why we gate on _fusion_bound
-    # rather than _phase_e_use_beta_coop. Symmetric: when consume fires,
-    # post_attn_LN must skip; when consume no-ops, post_attn_LN must apply.
-    if impl is not None and getattr(impl, "_fusion_bound", False):
-        # Fusion mode: β-coop already did post_attn_LN. Skip.
-        return
-    # Non-fusion mode: replicate _forward_static_with_residual in-place.
-    combined = hidden_states + residual
-    residual.copy_(combined)
-    x = combined.float()
-    var = x.pow(2).mean(dim=-1, keepdim=True)
-    x = x * torch.rsqrt(var + rmsnorm_eps)
-    x = x * (1.0 + weight.float())
-    hidden_states.copy_(x.to(combined.dtype))
-
-
-def _cute_post_attn_ln_dispatch_fake(
-    hidden_states: torch.Tensor,
-    residual: torch.Tensor,
-    weight: torch.Tensor,
-    rmsnorm_eps: float,
-    layer_name: str,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="cute_post_attn_ln_dispatch",
-    op_func=_cute_post_attn_ln_dispatch_impl,
-    mutates_args=["hidden_states", "residual"],
-    fake_impl=_cute_post_attn_ln_dispatch_fake,
-)
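
The reverted `_cute_post_attn_ln_dispatch_impl` replicates, in place, the fused-residual RMSNorm its docstring quotes from `_forward_static_with_residual`. As a quick sanity check of that math, here is a small CPU-only sketch comparing the out-of-place form from the docstring against the in-place replication; the helper names are illustrative, not vLLM APIs.

import torch


def reference(hidden_states, residual, weight, eps):
    # Out-of-place form quoted in the reverted docstring.
    combined = hidden_states + residual
    x = combined.float()
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(var + eps)
    x = x * (1.0 + weight.float())
    return x.to(combined.dtype), combined


def in_place(hidden_states, residual, weight, eps):
    # In-place replication, mirroring the body of the reverted op.
    combined = hidden_states + residual
    residual.copy_(combined)
    x = combined.float()
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(var + eps)
    x = x * (1.0 + weight.float())
    hidden_states.copy_(x.to(combined.dtype))


torch.manual_seed(0)
h = torch.randn(4, 8, dtype=torch.bfloat16)
r = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)

ref_out, ref_res = reference(h.clone(), r.clone(), w, 1e-6)
h2, r2 = h.clone(), r.clone()
in_place(h2, r2, w, 1e-6)
assert torch.equal(h2, ref_out) and torch.equal(r2, ref_res)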
