
Commit 514b88c

Natfii and claude committed
wip(cute): B-fix attempt — consume-gate DCE + post-attn-LN dispatch ops
WIP: partial fix for the C2 migration's consume-gate plumbing problem. This
commit will be reverted in the next commit; it is preserved in git history
for the follow-up architectural pass on feat/uber-kernel-migration. See
docs/research/uber_kernel_migration/2026-04-26-consume-gate-dce-and-graph-capture.md
(landing in the next commit) for the full diagnostic baseline.

What was diagnosed in this session
==================================

The C2 migration's premise — β-coop replaces Python o_proj +
post_attention_layernorm — was structurally unobservable to torch.compile
under PIECEWISE compile. Inspecting the captured FX graph at
/root/.cache/vllm/torch_compile_cache/<hash>/rank_0_0/backbone/computation_graph.py
revealed:

1. `cute_residual_mirror` was DCE-dropped despite
   `mutates_args=["residual_buf"]`. Dynamo's DCE removes ops whose mutations
   have no observable downstream reader IN THE GRAPH; impl.residual_buf is
   read inside opaque op bodies via Python-attribute access, invisible to
   dynamo's reachability analysis. `mutates_args` alone is NOT sufficient —
   the mutation needs an explicit downstream reader that is a graph input.

2. The `if getattr(impl, "_fusion_active", False)` consume gate at
   qwen3_5.py:466-476 was specialised by dynamo at trace time to always take
   the else branch (`_fusion_active = False` at __init__, mutated inside the
   unified_attention opaque op where dynamo can't see). In the captured
   graph, the legacy Python o_proj + post_attn_LN ALWAYS ran; β-coop's
   rmsnorm_output / residual_output were never read.

3. Dual-fire happened to produce coherent output entirely by accident: paged
   populated `output` with Phase A attn (via the framework op's declared
   mutates_args), Python o_proj computed wo_out from it, and Python
   post_attn_LN reconstructed residual_post_attn. β-coop's outputs were
   wasted. Solo (paged-skip) broke because nothing populated `output` with
   Phase A in solo mode.

What this commit attempted
==========================

Three opaque ops to replace the dead-eliminated Python branches:

- `cute_residual_mirror` (existing) — preserved across DCE by passing
  residual_buf as a phantom input to `cute_attn_consume`, giving the
  mutation a downstream reader.
- `cute_attn_consume` (new) — replaces the dead-eliminated consume branch.
  Always runs in the captured graph; dispatches at runtime via registry
  lookup of impl._fusion_bound. When β-coop fired, copies
  impl.rmsnorm_output → self_attention_output and impl.residual_output →
  residual.
- `cute_post_attn_ln_dispatch` (new) — replaces the dead-eliminated
  post_attn_LN gate. Skips when fusion-bound (β-coop did Phase C); applies
  the fused-residual RMSNorm in-place when not.

Result matrix
=============

| Mode                                        | Result      |
|---------------------------------------------|-------------|
| PIECEWISE + cudagraph_mode=NONE + solo      | COHERENT ✓  |
| PIECEWISE + cudagraph_mode=PIECEWISE + solo | GIBBERISH ✗ |

Under PIECEWISE+NONE the B-fix is correct: solo β-coop produces " Paris.
Paris is a city in France..." for the standard probe. Under PIECEWISE+graphs
(the production target), gibberish: the first token " Paris" is correct
(prefill works), then decode collapses into a single-token loop ("这种现象"
repeated). The captured graph contains all 4 ops (cute_residual_mirror,
cute_attn_consume, cute_post_attn_ln_dispatch, cute_phase_e_dispatch), but
the runtime output is wrong.

Failed pivots in this session
=============================

- v1: tensor signal `_fusion_active_signal` + `int(signal.item())` inside
  the op body. Crashed at warmup with `cudaErrorStreamCaptureInvalidated` —
  `.item()` causes a host-device sync that's incompatible with CUDA graph
  capture.
- v2: registry lookup of `impl._phase_e_use_beta_coop` (Python attr,
  per-step reset). Survived capture but produced gibberish.
- v3: registry lookup of `impl._fusion_bound` (set once at attach_fusion,
  stable across warmup + runtime). Same gibberish.

The graph-capture failure under cudagraph_mode=PIECEWISE remains unexplained
at the end of this session. Suspected root causes for the follow-up
architectural pass:

- vLLM V1 captures decode segments at warmup with shapes/state that diverge
  from runtime; Python-attr reads inside opaque op bodies don't reliably
  reflect runtime state.
- β-coop's cooperative launch + atomic-counter spin-wait may have CUDA-graph
  replay quirks independent of the consume gate.
- Some interaction between PIECEWISE's segment boundaries and the new opaque
  ops.

Why this is being reverted
==========================

The B-fix proves the consume-gate DCE is real and bounded — it works under
PIECEWISE+NONE. But shipping a partial fix that fails under the production
graph mode would be a regression. The architectural answer (have β-coop
write to the framework `output` directly so the Python pipeline becomes
unnecessary, OR use in-graph torch.cond/torch.where on tensor signals, OR
capture multiple graphs and dispatch externally) belongs in the C2 redesign
on feat/uber-kernel-migration, not patched onto a debug branch.

The next commit reverts this. The findings doc lands separately so it
remains in HEAD for the follow-up session.

Refs: memory:project_beta_coop_residual_solo_bug
      memory:project_uber_kernel_migration
      memory:feedback_pace_pressure (don't let pace drive design)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
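The DCE mechanism in point 1 and the phantom-input workaround reduce to a small standalone sketch. This uses the public torch.library.custom_op decorator rather than vLLM's direct_register_custom_op helper, and demo::mirror / demo::consume are hypothetical names; the point is only that listing the mutated buffer as an input to a later graph op gives the mutation an observable reader.

import torch

@torch.library.custom_op("demo::mirror", mutates_args={"buf"})
def mirror(x: torch.Tensor, buf: torch.Tensor) -> None:
    buf.copy_(x)  # side effect only: no traced node reads buf afterwards

@mirror.register_fake
def _(x: torch.Tensor, buf: torch.Tensor) -> None:
    return

@torch.library.custom_op("demo::consume", mutates_args={"out"})
def consume(out: torch.Tensor, phantom: torch.Tensor) -> None:
    # `phantom` is never read in the body; its presence as a graph input
    # is what creates the data dependency that keeps `mirror` alive.
    out.add_(1.0)

@consume.register_fake
def _(out: torch.Tensor, phantom: torch.Tensor) -> None:
    return

def forward(x: torch.Tensor, buf: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    mirror(x, buf)     # mutation with no in-graph reader: a DCE candidate
    consume(out, buf)  # phantom edge pins mirror's mutation into the graph
    return out

compiled = torch.compile(forward)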
1 parent 5a0311c commit 514b88c

3 files changed

Lines changed: 294 additions & 14 deletions


vllm/nvllm/models/qwen3_5.py

Lines changed: 69 additions & 14 deletions
@@ -463,17 +463,50 @@ def _ct_mark(label: str) -> None:
                 positions=positions,
             )
             _ct_mark("self_attn")
-            if impl is not None and getattr(impl, "_fusion_active", False):
-                # Kernel already did gate*attn, W_O GEMV, residual+RMSNorm.
-                self_attention_output[:nat].copy_(impl.rmsnorm_output[:nat])
-                if nat < num_tokens:
-                    self_attention_output[nat:].zero_()
-                residual[:nat].copy_(impl.residual_output[:nat])
-                hidden_states = self_attention_output
-                _ct_mark("attn_consume")
-            else:
-                hidden_states = self_attention_output
-                _ct_mark("attn_legacy")
+            # 2026-04-26 (B-fix): the prior `if getattr(impl, "_fusion_active",
+            # False)` Python-bool gate was dead-eliminated by torch.compile —
+            # at trace time `_fusion_active` was False (impl __init__ default),
+            # so dynamo specialised the if-branch as dead and the captured
+            # graph always ran the else fall-through. Empirically verified
+            # via /root/.cache/vllm/torch_compile_cache/<hash>/.../
+            # computation_graph.py: the consume `.copy_()` calls were absent
+            # AND the legacy Python o_proj path was always present.
+            #
+            # Replace with an opaque op (cute_attn_consume) that always runs
+            # in the captured graph and dispatches at runtime via a
+            # _CUTE_ATTN_REGISTRY lookup of the impl by layer_name (the v1
+            # 0-dim tensor signal read via `.item()` broke CUDA graph
+            # capture). When the impl is not fusion-bound the op no-ops;
+            # when β-coop fired, it copies β-coop's outputs into
+            # self_attention_output and residual.
+            #
+            # residual_buf and gate_buf are passed as PHANTOM inputs: they
+            # are not used inside the op body, but their presence forces
+            # a data dependency on cute_residual_mirror's output, which
+            # otherwise gets DCE'd despite mutates_args (verified: only
+            # mutates_args is NOT enough to survive DCE if no graph op
+            # reads the mutated tensor).
+            if impl is not None and getattr(impl, "_fusion_bound", False):
+                # Consistent with the cute_residual_mirror gate above
+                # (_fusion_bound is set in attach_fusion, stable at trace
+                # time — dynamo's specialization on it is correct because
+                # it's a one-time setup flag, not a per-step runtime flag).
+                # The op uses _CUTE_ATTN_REGISTRY[layer_name] internally to
+                # read impl._fusion_bound at runtime — Python attr access
+                # only, no .item() / no CUDA sync, safe under graph capture
+                # (verified failure mode 2026-04-26 from .item():
+                # cudaErrorStreamCaptureInvalidated).
+                torch.ops.vllm.cute_attn_consume(
+                    self_attention_output,
+                    residual,
+                    impl.rmsnorm_output,
+                    impl.residual_output,
+                    impl.residual_buf,
+                    impl.gate_buf,
+                    self.self_attn.attn.layer_name,
+                )
+            hidden_states = self_attention_output
+            _ct_mark("attn_consume_or_legacy")
         else:
             raise ValueError("Invalid layer_type")
 
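The dead-elimination this hunk works around can be reproduced in isolation. A minimal sketch, assuming a flag mutated inside an opaque op body (demo::attn and Impl are illustrative names): dynamo evaluates the Python bool at trace time and bakes one branch into the graph, so the mutation cannot redirect control flow within the captured call, which is exactly the always-else behaviour diagnosed above.

import torch

class Impl:
    def __init__(self) -> None:
        self._fusion_active = False  # trace-time default, as in __init__

impl = Impl()

@torch.library.custom_op("demo::attn", mutates_args={"out"})
def attn(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x)
    impl._fusion_active = True  # runtime mutation, invisible to dynamo

@attn.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    return

def layer(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    attn(x, out)
    # Dynamo reads the Python bool while tracing (still False: the fake
    # impl never sets it), so only the else path exists in the graph.
    if impl._fusion_active:
        return out
    return x

compiled = torch.compile(layer)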
@@ -487,13 +520,35 @@ def _ct_mark(label: str) -> None:
                 self.attn_layer_scale.to(hidden_states.dtype) + 1
             )
 
-            if not getattr(impl, "_fusion_active", False):
+            # 2026-04-26 (B-fix): post_attn_LN dispatch via opaque op for full
+            # attention layers (replacing the dead-eliminated Python-bool gate).
+            # The prior `if not getattr(impl, "_fusion_active", False)` was
+            # specialised to always-run by dynamo (because trace-time
+            # `_fusion_active = False`, so `not False = True`). The captured
+            # graph ran post_attn_LN unconditionally — fine in dual-fire (β-coop's
+            # rmsnorm_output was unused anyway, the Python pipeline did the work)
+            # but in solo it operated over uninitialised self_attention_output,
+            # because β-coop doesn't expose Phase A to the framework `output`
+            # parameter.
+            #
+            # Linear-attention layers have impl=None and no fusion signal, so
+            # they keep the plain Python module call (no compile fragility there
+            # — the dead-elim only bites paths that depend on a runtime-mutated
+            # Python attribute).
+            if impl is not None and getattr(impl, "_fusion_bound", False):
+                torch.ops.vllm.cute_post_attn_ln_dispatch(
+                    hidden_states,
+                    residual,
+                    self.post_attention_layernorm.weight,
+                    float(self.post_attention_layernorm.variance_epsilon),
+                    self.self_attn.attn.layer_name,
+                )
+                _ct_mark("post_attn_ln_dispatch")
+            else:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
                 )
                 _ct_mark("post_attn_ln")
-            else:
-                _ct_mark("post_attn_skip")
 
             # Phase E β-lite consume. When the CuTe backend launched the
             # β-lite dispatch inside its forward, the MLP kernel's ε epilogue
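The math the new dispatch op must reproduce is fully determined by the reference quoted later in _mlp_op.py's docstring. A small parity sketch under assumed shapes and eps (illustrative values, CPU bf16), checking that the in-place formulation matches the out-of-place reference:

import torch

def ref_post_attn_ln(hidden_states, residual, weight, eps):
    # Out-of-place reference (see the dispatch op's docstring below).
    combined = hidden_states + residual
    x = combined.float()
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(var + eps)
    x = x * (1.0 + weight.float())
    return x.to(combined.dtype), combined

h = torch.randn(4, 64, dtype=torch.bfloat16)
r = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(64, dtype=torch.bfloat16)
h_ref, r_ref = ref_post_attn_ln(h, r, w, 1e-6)

# In-place formulation, as the dispatch op applies when fusion is off.
h2, r2 = h.clone(), r.clone()
combined = h2 + r2
r2.copy_(combined)
x = combined.float()
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
h2.copy_((x * (1.0 + w.float())).to(combined.dtype))

assert torch.allclose(h2.float(), h_ref.float())
assert torch.allclose(r2.float(), r_ref.float())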

vllm/v1/attention/backends/cute_paged/_backend.py

Lines changed: 48 additions & 0 deletions
@@ -336,6 +336,20 @@ def _preallocate_fusion_buffers(
             max_num_seqs, hidden_dim, dtype=torch.bfloat16, device=device
         )
 
+        # 2026-04-26 (B-fix): runtime signal tensor for the consume-or-postln
+        # dispatch in qwen3_5.py. 0-dim int32 written inside impl.forward
+        # (which is wrapped in the unified_attention opaque op, invisible
+        # to dynamo). Originally read by `cute_attn_consume` and
+        # `cute_post_attn_ln_dispatch` via `.item()`; that host-device sync
+        # broke CUDA graph capture, so the ops now gate on a registry lookup
+        # and this tensor is kept only for debug visibility. Value: 0 =
+        # fusion didn't fire (run Python o_proj/post_attn_LN normally),
+        # N > 0 = fusion fired with N tokens (use β-coop outputs, skip
+        # post_attn_LN).
+        self._fusion_active_signal = torch.zeros(
+            (), dtype=torch.int32, device=device
+        )
+
         # Phase D MLP fusion buffers. Shape-defining axes (`slice_ctas`
         # for `mlp_partial_fp32`, `num_k_tiles` for `mlp_arrival_count`)
         # are both kernel-side constants resolved inside
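The v1 failure mode this comment records can be reproduced in a few lines, assuming a CUDA device (the exact error text varies by PyTorch version; vLLM surfaced it as cudaErrorStreamCaptureInvalidated at warmup):

import torch

def gate_on_signal(signal: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # .item() is a blocking device-to-host copy: a host-device sync that
    # is not permitted while a stream is being captured.
    if int(signal.item()) > 0:
        return x + 1
    return x - 1

signal = torch.zeros((), dtype=torch.int32, device="cuda")
x = torch.ones(8, device="cuda")
g = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(g):
        _ = gate_on_signal(signal, x)
except RuntimeError as err:
    print("capture failed:", err)  # expected: a stream-capture error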
@@ -655,6 +669,15 @@ def _resolve_fusion_weights(self) -> None:
             return
 
         self._fusion_bound = True
+        # 2026-04-26 (B-fix): register self in the attn-consume registry so
+        # cute_attn_consume / cute_post_attn_ln_dispatch can look up the impl
+        # at runtime via the layer_name string. Avoids passing impl as a
+        # custom-op arg (not supported) AND avoids reading a 0-dim tensor
+        # signal via .item() (cudaErrorStreamCaptureInvalidated under capture).
+        from vllm.v1.attention.backends.cute_paged._mlp_op import (
+            _CUTE_ATTN_REGISTRY,
+        )
+        _CUTE_ATTN_REGISTRY[self._fusion_prefix] = self
         logger.info(
             "CuTe fusion resolved: layer=%s wo_weight=%s rmsnorm_gamma=%s",
             self._fusion_prefix,
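Reduced to a standalone sketch, the registry pattern registered here works as follows: a module-level dict maps a stable string key to a live impl object, and the op body resolves it at call time, so neither an object-typed custom-op argument nor a tensor-signal sync is needed. All names (_REGISTRY, DemoImpl, demo::lookup_scale) are illustrative.

import torch

class DemoImpl:
    def __init__(self, scale: float) -> None:
        self.scale = scale

_REGISTRY: dict[str, DemoImpl] = {}

@torch.library.custom_op("demo::lookup_scale", mutates_args={"out"})
def lookup_scale(out: torch.Tensor, key: str) -> None:
    impl = _REGISTRY.get(key)  # resolved at call time, not at trace time
    if impl is not None:
        out.mul_(impl.scale)

@lookup_scale.register_fake
def _(out: torch.Tensor, key: str) -> None:
    return

_REGISTRY["layers.3.self_attn.attn"] = DemoImpl(scale=2.0)
out = torch.ones(4)
lookup_scale(out, "layers.3.self_attn.attn")  # out is now all 2.0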
@@ -1020,6 +1043,23 @@ def forward(
         fits_buffer = num_actual_tokens <= getattr(self, "_fusion_max_num_seqs", 0)
         self._fusion_active = self._fusion_bound and is_decode_only and fits_buffer
         use_fusion = self._fusion_active
+        # 2026-04-26 (B-fix): per-step reset for the consume gate. The
+        # consume ops (cute_attn_consume / cute_post_attn_ln_dispatch)
+        # gate at runtime on `_fusion_bound`, keyed off impl from
+        # _CUTE_ATTN_REGISTRY by layer_name; `_phase_e_use_beta_coop` is
+        # still reset here so the Python-side record reflects THIS forward
+        # call (β-coop may not fire even when fusion is bound, e.g. the
+        # predicate fails or the kernel falls back to β-lite/paged via
+        # the except handler below).
+        #
+        # _fusion_active_signal stays a 0-dim tensor for debug visibility
+        # but is NOT read inside the consume ops anymore — Python attr
+        # access avoids the .item() host-device sync that broke CUDA
+        # graph capture (cudaErrorStreamCaptureInvalidated 2026-04-26).
+        # The .fill_(0) reset below stays commented so it can be re-enabled.
+        self._phase_e_use_beta_coop = False
+        # if hasattr(self, "_fusion_active_signal"):
+        #     self._fusion_active_signal.fill_(0)
         # --- PHASE D2 DISABLED (commented, not deleted — Phase B/C debug may
         # need this reset back) ---
         # Pre-D2, the MLP fusion launch was an attention-side side effect
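One concrete reading of the commit message's first suspected root cause: CUDA graph replay never re-runs host-side Python, so attribute reads and branches inside op bodies are frozen at capture time. A sketch, assuming a CUDA device:

import torch

class Impl:
    flag = False

impl = Impl()

def op_body(x: torch.Tensor) -> None:
    if impl.flag:      # host-side branch: evaluated once, during capture
        x.add_(1.0)
    else:
        x.sub_(1.0)

x = torch.zeros(4, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    op_body(x)         # records the sub_ kernel (flag was False)

impl.flag = True       # does not alter the recorded graph
g.replay()             # still replays sub_: x becomes -1, not +1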
@@ -1310,6 +1350,14 @@ def forward(
             )
             self._phase_e_consumed = True
             self._phase_e_use_beta_coop = True
+            # 2026-04-26 (B-fix): the consume gate now reads impl state
+            # (`_fusion_bound`, a Python attr) inside the opaque op body
+            # via _CUTE_ATTN_REGISTRY lookup — no .item() call, no
+            # host-device sync, CUDA-graph-safe. The tensor signal
+            # `.fill_(nat)` below is kept commented (not deleted) so it
+            # can be re-enabled if a future debug session wants
+            # tensor-side visibility into β-coop firing decisions.
+            # self._fusion_active_signal.fill_(nat)
             # 2026-04-26: ENV-GATED dump for off-line math verification.
             # CUTE_DUMP_TENSORS=1 enables; bounded to first 3 decode
             # steps × 16 full-attn layers so disk doesn't bloat. Files

vllm/v1/attention/backends/cute_paged/_mlp_op.py

Lines changed: 177 additions & 0 deletions
@@ -46,6 +46,17 @@
 # The op body reads from this dict at runtime (not at trace time).
 _CUTE_MLP_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {}
 
+# 2026-04-26 (B-fix): attn-consume registry, populated by
+# `CutePagedAttentionImpl.attach_fusion`. Same impl object as
+# _CUTE_MLP_REGISTRY but keyed by ATTENTION layer name (e.g.
+# `language_model.model.layers.3.self_attn.attn`), not the MLP key
+# used by cute_phase_e_dispatch. Allows cute_attn_consume and
+# cute_post_attn_ln_dispatch to look up the impl and read its
+# Python-side flags at runtime — avoids the .item() host-device sync
+# on a 0-dim tensor signal (which raises cudaErrorStreamCaptureInvalidated
+# under CUDA graph capture, verified 2026-04-26).
+_CUTE_ATTN_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {}
+
 
 def _cute_mlp_forward_impl(
     x: torch.Tensor,
@@ -309,3 +320,169 @@ def _cute_residual_mirror_fake(
     mutates_args=["residual_buf"],
     fake_impl=_cute_residual_mirror_fake,
 )
+
+
+# --- 2026-04-26: cute_attn_consume + cute_post_attn_ln_dispatch ----------------
+# B-fix: replace the dead-eliminated Python `if _fusion_active` consume branch
+# at qwen3_5.py:466-476 and the dead-eliminated `if not _fusion_active`
+# post_attention_layernorm gate at qwen3_5.py:490-496.
+#
+# WHY needed: the captured FX graph (verified 2026-04-26 via
+# /root/.cache/vllm/torch_compile_cache/<hash>/rank_0_0/backbone/computation_graph.py)
+# specialized BOTH gates at trace time on `_fusion_active = False` (the impl's
+# __init__ default) — dynamo can't see the runtime mutation that happens inside
+# the unified_attention opaque op. Result: the consume copy was DCE'd, the
+# legacy Python o_proj + post_attn_LN ALWAYS ran, β-coop's rmsnorm_output /
+# residual_output were never read by the captured graph. In dual-fire this
+# happened to produce coherent output because paged populated `output` with
+# Phase A and the Python pipeline applied o_proj + post_attn_LN over it. In
+# solo (paged gated off, β-coop only), `output` stayed uninitialised and
+# Python applied o_proj over junk → gibberish.
+#
+# Fix: make both ops unconditional in the captured graph and dispatch at
+# runtime via a _CUTE_ATTN_REGISTRY lookup of the impl's `_fusion_bound`
+# flag (the v1 0-dim int32 tensor signal read via .item() raised
+# cudaErrorStreamCaptureInvalidated under CUDA graph capture and was
+# abandoned):
+#   not bound : non-fusion mode (β-coop can't fire). consume no-ops;
+#               postln applies the fused-residual RMSNorm in-place over
+#               the Python o_proj's wo_out.
+#   bound     : fusion mode (β-coop fired). consume copies β-coop's
+#               rmsnorm_output → self_attention_output and
+#               residual_output → residual; postln no-ops (β-coop's Phase
+#               1C already produced LN(post_input_LN_residual + wo_out)·γ).
+#
+# residual_buf and gate_buf are passed to consume as PHANTOM inputs (not
+# read inside the body) — their sole purpose is to give the cute_residual_mirror
+# and cute_residual_mirror(gate_buf, ...) ops observable downstream readers
+# in the captured graph, which prevents dynamo's DCE from dropping them
+# (verified empirically that mutates_args alone is NOT sufficient against
+# DCE — the ops were dead-eliminated despite mutates_args=["residual_buf"]
+# until a downstream reader was added).
+
+
+def _cute_attn_consume_impl(
+    self_attention_output: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
+    residual: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
+    rmsnorm_output: torch.Tensor,  # impl.rmsnorm_output [max_num_seqs, hidden_dim] BF16
+    residual_output: torch.Tensor,  # impl.residual_output [max_num_seqs, hidden_dim] BF16
+    residual_buf: torch.Tensor,  # phantom for cute_residual_mirror dep
+    gate_buf: torch.Tensor,  # phantom for gate-mirror dep
+    layer_name: str,  # registry key into _CUTE_ATTN_REGISTRY
+) -> None:
+    """If fusion is bound for this layer: copy β-coop's outputs into
+    model-side tensors.
+
+    Gates on `impl._fusion_bound` (Python attr) at runtime via
+    `_CUTE_ATTN_REGISTRY[layer_name]` — no .item() call, no CUDA sync,
+    safe under CUDA graph capture. See the v2 note below for why the
+    per-step `_phase_e_use_beta_coop` flag is NOT used as the gate.
+    """
+    impl = _CUTE_ATTN_REGISTRY.get(layer_name)
+    # 2026-04-26 (B-fix v2): gate on `_fusion_bound` (set once at
+    # attach_fusion, stable across warmup + runtime) rather than
+    # `_phase_e_use_beta_coop` (set per-step inside impl.forward — not
+    # consistently True at warmup capture time, so the captured segment
+    # would skip the consume kernels and replay would never fill
+    # self_attention_output from β-coop's outputs). With _fusion_bound:
+    # capture always sees True for fusion-bound full-attn layers, so the
+    # consume kernels are always captured. Cost: if β-coop ever fails to
+    # fire at runtime (e.g. predicate fails), consume reads stale
+    # impl.rmsnorm_output. Mitigated by the predicate hard-gate landed
+    # in the prior commit, which prevents silent β-coop fallthrough on
+    # cooperative-launch-too-large.
+    if impl is None or not getattr(impl, "_fusion_bound", False):
+        # Non-fusion / non-bound: leave self_attention_output as-is (Python
+        # o_proj already wrote it) and residual untouched.
+        return
+    # Fusion mode: β-coop's Phase 1C produced these. Bound by buffer capacity
+    # defensively (matches the original Python consume branch).
+    nat = min(self_attention_output.shape[0], rmsnorm_output.shape[0])
+    self_attention_output[:nat].copy_(rmsnorm_output[:nat])
+    if nat < self_attention_output.shape[0]:
+        # Match the prior `if nat < num_tokens: self_attention_output[nat:].zero_()`
+        # — keeps unused rows deterministic across decode steps.
+        self_attention_output[nat:].zero_()
+    residual[:nat].copy_(residual_output[:nat])
+
+
+def _cute_attn_consume_fake(
+    self_attention_output: torch.Tensor,
+    residual: torch.Tensor,
+    rmsnorm_output: torch.Tensor,
+    residual_output: torch.Tensor,
+    residual_buf: torch.Tensor,
+    gate_buf: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="cute_attn_consume",
+    op_func=_cute_attn_consume_impl,
+    # Both self_attention_output and residual are mutated when fusion fires;
+    # the phantom inputs are read-only.
+    mutates_args=["self_attention_output", "residual"],
+    fake_impl=_cute_attn_consume_fake,
+)
+
+
+def _cute_post_attn_ln_dispatch_impl(
+    hidden_states: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
+    residual: torch.Tensor,  # mutated [num_tokens, hidden_dim] BF16
+    weight: torch.Tensor,  # post_attention_layernorm.weight [hidden_dim] BF16
+    rmsnorm_eps: float,
+    layer_name: str,  # registry key into _CUTE_ATTN_REGISTRY
+) -> None:
+    """If fusion is NOT bound: apply fused-residual post_attention_layernorm.
+
+    Mirrors `_forward_static_with_residual` in vllm/nvllm/layers/layernorm.py:
+        combined = hidden_states + residual
+        residual = combined
+        x = combined.float()
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(var + eps)
+        x = x * (1.0 + weight.float())
+        hidden_states = x.to(combined.dtype)
+
+    When β-coop fired, its Phase 1C already produced this exact output into
+    hidden_states via cute_attn_consume above, and residual already holds
+    residual_post_attn — skip to avoid double-LN.
+
+    Gates on `impl._fusion_bound` (Python attr) — no .item() needed,
+    CUDA-graph-safe. See cute_attn_consume above for the gate semantics.
+    """
+    impl = _CUTE_ATTN_REGISTRY.get(layer_name)
+    # See cute_attn_consume above for why we gate on _fusion_bound
+    # rather than _phase_e_use_beta_coop. Symmetric: when consume fires,
+    # post_attn_LN must skip; when consume no-ops, post_attn_LN must apply.
+    if impl is not None and getattr(impl, "_fusion_bound", False):
+        # Fusion mode: β-coop already did post_attn_LN. Skip.
+        return
+    # Non-fusion mode: replicate _forward_static_with_residual in-place.
+    combined = hidden_states + residual
+    residual.copy_(combined)
+    x = combined.float()
+    var = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(var + rmsnorm_eps)
+    x = x * (1.0 + weight.float())
+    hidden_states.copy_(x.to(combined.dtype))
+
+
+def _cute_post_attn_ln_dispatch_fake(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    rmsnorm_eps: float,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="cute_post_attn_ln_dispatch",
+    op_func=_cute_post_attn_ln_dispatch_impl,
+    mutates_args=["hidden_states", "residual"],
+    fake_impl=_cute_post_attn_ln_dispatch_fake,
+)
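For the follow-up pass, the in-graph alternative named in the commit message reduces to a sketch like the following: torch.cond keeps both branches in the captured graph and selects on a 0-dim tensor predicate at runtime, avoiding both Python-bool specialization and .item() syncs. torch.cond requires functional branch functions (no mutation or input aliasing), so the in-place consume would become an out-of-place select; names here are illustrative.

import torch

def fused(attn_out: torch.Tensor, rms_out: torch.Tensor) -> torch.Tensor:
    # β-coop already produced the final hidden states. clone(): cond
    # branches may not alias their inputs.
    return rms_out.clone()

def legacy(attn_out: torch.Tensor, rms_out: torch.Tensor) -> torch.Tensor:
    # Defer to the Python o_proj / post_attn_LN pipeline's result.
    return attn_out.clone()

def layer_tail(fired: torch.Tensor, attn_out: torch.Tensor,
               rms_out: torch.Tensor) -> torch.Tensor:
    # `fired` is a 0-dim bool tensor written on-device by the attention op;
    # both branches are captured, selection happens at runtime.
    return torch.cond(fired, fused, legacy, (attn_out, rms_out))

compiled = torch.compile(layer_tail)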
