Commit 5a0311c

Natfii and claude committed
fix(cute): C2 plumbing — residual/gate mirror op + β-coop predicate hard-gate
Two correctness bugs and one no-silent-fallback hardening:

1) residual_buf + gate_buf dynamo dead-elimination

Both qwen3_5.py call sites for the BF16 residual / gate mirror `.copy_()` lived inside try/except blocks whose protected line `get_forward_context().attn_metadata[layer_name]` raises at torch.compile trace time (forward_context is None). Dynamo concluded the try body was always-caught dead code, and the captured PIECEWISE graph dropped the .copy_. At runtime the buffers stayed at the CUDA-graph-allocator-zeroed value → β-coop / paged read zeros → gibberish. Verified 2026-04-26 via /tmp/nvllm-dumps: residual_in absmax=0.0 across all 16 full-attn layers pre-fix.

Fix: new `cute_residual_mirror` opaque op in _mlp_op.py with `mutates_args=["residual_buf"]`. The first-pass attempt with `mutates_args=[]` was still dead-eliminated — the mutates_args declaration is what tells torch.compile the op has a real side effect on a tracked tensor. Both qwen3_5.py call sites (Qwen3_5DecoderLayer.forward residual_buf @L427, Qwen3_5Attention.forward gate_buf @L253) now route through the op.

This was an actual bug present before β-coop ever fired: the paged kernel was silently reading a zero residual_buf in any PIECEWISE deployment using fusion. Standalone correctness win.

2) β-coop predicate hard-gate (no-silent-fallback)

`_will_fire_beta_coop_pre` and `_use_beta_coop` previously bypassed the `(64 * num_seqs) <= _resident_cap` cooperative-launch fitness check when forced_path == "coop", under the assumption "user asked for coop, they know what they're doing." But on multi-seq decode (e.g. nat=3 batches) the fixed grid exceeds the resident cap → CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE → except-handler fallthrough to β-lite. β-lite is MLP-only with no attention → silent gibberish.

Fix: cooperative-launch fitness is now a HARD gate regardless of forced_path. If the grid won't fit, paged_attention_forward stays in the decode path. The predicate is duplicated at two sites (`_will_fire_beta_coop_pre` for the paged-skip decision and `_use_beta_coop` for the dispatch) — kept in sync via comment cross-refs. Per memory:feedback_no_silent_fallbacks.

3) C2 attn-output-gate wired through the β-coop kernel

phase_e_kernel.py: gate_ptr + gate_fused flag added to PhaseE_Beta_Kernel.run_beta_coop_full and to the JIT signature. gate_fused == 0 disables the multiply (back-compat for callers that don't supply gate_buf). The _backend.py β-coop dispatch passes self.gate_buf[:nat]. Mirrors paged kernel.py:1555-1569. This is the consumer side of fix #1 — without #1 the gate buffer was always zero, so the flag's effect could not have been observed.

4) Env-gated tensor dump harness (kept per feedback_keep_debug_harnesses)

_backend.py β-coop branch: CUTE_DUMP_TENSORS=1 dumps {residual_in, query, gate, residual_out, rmsnorm_out} per (layer × decode step), bounded to 3 steps × 16 layers. Files land in /tmp/nvllm-dumps/. serve-cute.sh adds the bind mount and env passthrough. Used to bisect this bug; keeping it for the next graph-capture investigation.

Also: the BETA_DIFF harness clones paged's wo_output / rmsnorm_output / residual_output before β-coop overwrites them, then logs the delta. Gated on CUTE_DEBUG_FUSION=1, and it only fires in dual-fire mode (skipped when paged is gated off). Verified BETA_DIFF=0 with FIXED inputs — β-coop math byte-identical to paged.

Validation matrix (2026-04-26 EOD, ig1/Qwen3.5-27B-NVFP4):

- PIECEWISE + paged-only: COHERENT ✓
- PIECEWISE + dual-fire (paged + β-coop): COHERENT ✓, BETA_DIFF=0
- PIECEWISE + solo β-coop: GIBBERISH ✗ (remaining)
- EAGER + solo β-coop: COHERENT ✓

The remaining solo-β-coop gibberish under PIECEWISE is upstream of β-coop entirely — layer 3 inputs (the first full-attn layer, after 3 untouched linear-attn layers) differ between dual-fire and solo modes for the same prompt + seed. The captured CUDA graph layout / compile artifact differs depending on whether paged is also in the captured segment. Investigation paths are in memory:project_beta_coop_residual_solo_bug. Side-by-side dumps are preserved at /tmp/nvllm-dumps-{dualfire,solo} (80 files each) for the next session.

Refs:
memory:project_beta_coop_residual_solo_bug
memory:project_uber_kernel_migration
memory:feedback_no_silent_fallbacks
memory:feedback_keep_debug_harnesses
memory:feedback_layer_output_contract

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
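For reference, a rough sketch of the fix pattern from item 1 using stock PyTorch APIs rather than the repo's direct_register_custom_op helper (placeholder op and library names, not repo code): declaring the mutated buffer in mutates_args is what keeps the copy alive through torch.compile graph capture.

# Sketch only: the "demo::" namespace and names are placeholders, not the repo's op.
import torch
from torch.library import custom_op

@custom_op("demo::residual_mirror", mutates_args=("residual_buf",))
def residual_mirror(residual_buf: torch.Tensor, residual: torch.Tensor) -> None:
    # Declared mutation on residual_buf: torch.compile must preserve this call,
    # so the copy survives graph capture instead of being dead-eliminated.
    nat = min(residual.shape[0], residual_buf.shape[0])
    if nat:
        residual_buf[:nat].copy_(residual[:nat])

@residual_mirror.register_fake
def _(residual_buf: torch.Tensor, residual: torch.Tensor) -> None:
    # Trace-time fake: no-op; the real copy only matters at runtime.
    return None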
1 parent 54da780 commit 5a0311c

5 files changed: 301 additions & 56 deletions

scripts/serve-cute.sh

Lines changed: 2 additions & 0 deletions
@@ -70,10 +70,12 @@ docker run -d \
   --network host \
   -v "$HOME/.cache/huggingface:/root/.cache/huggingface" \
   -v "$HOME/.cache/flashinfer:/root/.cache/flashinfer" \
+  -v "/tmp/nvllm-dumps:/tmp/nvllm-dumps" \
   -e VLLM_NVFP4_GEMM_BACKEND=cutlass \
   -e VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
   -e PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
   -e CUTE_DEBUG_FUSION="${CUTE_DEBUG_FUSION:-0}" \
+  -e CUTE_DUMP_TENSORS="${CUTE_DUMP_TENSORS:-0}" \
   -e CUTE_MLP_FUSION="${CUTE_MLP_FUSION:-1}" \
   -e CUTE_ATTN_FUSION="${CUTE_ATTN_FUSION:-1}" \
   -e CUTE_DEBUG_MLP_FUSION="${CUTE_DEBUG_MLP_FUSION:-0}" \

vllm/nvllm/models/qwen3_5.py

Lines changed: 23 additions & 23 deletions
@@ -253,20 +253,15 @@ def forward(
         # is a cheap one-off BF16 memcpy; it avoids the old model->impl flag
         # side-channel that was flagged as fragile.
         if gate is not None:
-            from vllm.forward_context import get_forward_context
-
+            # 2026-04-26: gate_buf mirror via the same opaque op as
+            # residual_buf — the prior plain-Python .copy_() inside
+            # try/except (which protected the trace-time-failing
+            # `attn_metadata[...]` lookup) was being dead-eliminated by
+            # @support_torch_compile dynamo. Same root cause, same fix.
             impl = self.attn.impl
             gate_buf = getattr(impl, "gate_buf", None)
             if gate_buf is not None:
-                try:
-                    nat = (
-                        get_forward_context()
-                        .attn_metadata[self.attn.layer_name]
-                        .num_actual_tokens
-                    )
-                    gate_buf[:nat].copy_(gate[:nat])
-                except (RuntimeError, KeyError, AttributeError, TypeError):
-                    pass
+                torch.ops.vllm.cute_residual_mirror(gate_buf, gate)

         attn_output = self.attn(q, k, v)

@@ -432,18 +427,23 @@ def _ct_mark(label: str) -> None:
         impl = None
         if self.layer_type == "full_attention":
             impl = self.self_attn.attn.impl
-            fusion_could_run = getattr(impl, "_fusion_bound", False)
-            if fusion_could_run:
-                try:
-                    from vllm.forward_context import get_forward_context
-
-                    attn_md = get_forward_context().attn_metadata[
-                        self.self_attn.attn.layer_name
-                    ]
-                    nat = attn_md.num_actual_tokens
-                    impl.residual_buf[:nat].copy_(residual[:nat])
-                except (RuntimeError, KeyError, AttributeError, TypeError):
-                    pass
+            # 2026-04-26: residual mirror via opaque custom op. The prior
+            # plain-Python .copy_(residual) was inside a try/except whose
+            # protected lookup `get_forward_context().attn_metadata[...]`
+            # threw at torch.compile trace time. dynamo concluded the
+            # try body was always-caught dead code and the captured
+            # graph dropped the .copy_. At runtime residual_buf stayed at
+            # the CUDA-graph-allocator-zeroed value → β-coop read zeros
+            # → gibberish.
+            #
+            # The opaque op preserves the side effect across graph capture.
+            # `residual_buf` is a declared mutates_args so torch.compile
+            # tracks the mutation as a real side effect (the prior op
+            # version with `mutates_args=[]` was still dead-eliminated).
+            if getattr(impl, "_fusion_bound", False):
+                torch.ops.vllm.cute_residual_mirror(
+                    impl.residual_buf, residual
+                )
         _ct_mark("residual_mirror")

         self_attention_output = torch.empty_like(hidden_states)

vllm/v1/attention/backends/cute_paged/_backend.py

Lines changed: 154 additions & 33 deletions
@@ -1072,35 +1072,75 @@ def forward(
         num_seqs = len(attn_metadata.seq_lens)
         padded_num_seqs = num_seqs  # graph capture overrides via metadata

-        result = paged_attention_forward(
-            query=query[:num_actual_tokens],
-            kv_cache=kv_cache,
-            page_table=attn_metadata.block_table,
-            seq_lens=attn_metadata.seq_lens,
-            scale=self.scale,
-            k_scale=k_scale,
-            v_scale=v_scale,
-            page_size=64,
-            query_start_loc=attn_metadata.query_start_loc,
-            wo_weight=wo_weight,
-            wo_scales=wo_scales,
-            wo_global_scale=wo_global_scale,
-            wo_output=wo_output,
-            rmsnorm_gamma=rmsnorm_gamma,
-            rmsnorm_residual=rmsnorm_residual,
-            rmsnorm_output=rmsnorm_output,
-            residual_output=residual_output,
-            arrival_count=arrival_count,
-            rmsnorm_eps=rmsnorm_eps,
-            gate_buf=gate_buf,
-            padded_num_seqs=padded_num_seqs,
+        # C2: gate paged_attention_forward off on decode when β-coop is
+        # going to fire — β-coop is the sole Phase A+B+C+3+4 uber-kernel
+        # in that path, paged becomes redundant double-fire. We replicate
+        # the _use_beta_coop predicate computed below so the gate matches.
+        # NOTE: kept commented-out OFF gate sites (none here) per
+        # feedback_comment_not_delete; the only structural change is the
+        # `if _will_fire_beta_coop_pre:` wrapper around the paged call.
+        _phase_e_env_pre = _phase_e_env_config()
+        # 2026-04-26: cooperative-launch fitness (64*num_seqs <= _resident_cap)
+        # is a HARD gate even in forced-coop mode. Previously bypassed when
+        # forced_path=="coop", which caused CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE
+        # on multi-seq decode (e.g., nat=3 batches). Because the paged-skip
+        # below assumes β-coop will run, a coop launch failure left
+        # self.rmsnorm_output stale → β-lite (MLP-only) ran on garbage →
+        # silent gibberish. Keep this in sync with `_use_beta_coop` below.
+        _will_fire_beta_coop_pre = (
+            _phase_e_env_pre.enabled
+            and is_decode_only
+            and use_fusion
+            and getattr(self, "_phase_e_coop_kernel", None) is not None
+            and getattr(self, "_mlp_fusion_bound", False)
+            and num_actual_tokens <= getattr(self, "_fusion_max_num_seqs", 0)
+            and (64 * num_seqs) <= getattr(self, "_resident_cap", 0)
+            and _phase_e_env_pre.forced_path in ("coop", "auto")
         )
+        if _will_fire_beta_coop_pre:
+            result = None
+            # Mark snapshots stale so the BETA_DIFF harness skips below.
+            self._debug_paged_res = None
+        else:
+            result = paged_attention_forward(
+                query=query[:num_actual_tokens],
+                kv_cache=kv_cache,
+                page_table=attn_metadata.block_table,
+                seq_lens=attn_metadata.seq_lens,
+                scale=self.scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                page_size=64,
+                query_start_loc=attn_metadata.query_start_loc,
+                wo_weight=wo_weight,
+                wo_scales=wo_scales,
+                wo_global_scale=wo_global_scale,
+                wo_output=wo_output,
+                rmsnorm_gamma=rmsnorm_gamma,
+                rmsnorm_residual=rmsnorm_residual,
+                rmsnorm_output=rmsnorm_output,
+                residual_output=residual_output,
+                arrival_count=arrival_count,
+                rmsnorm_eps=rmsnorm_eps,
+                gate_buf=gate_buf,
+                padded_num_seqs=padded_num_seqs,
+            )
+
+        # --- BETA_DIFF harness: snapshot paged's outputs so we can diff
+        # against β-coop's overwrite later. Gated on CUTE_DEBUG_FUSION=1.
+        # Only fires when paged actually ran (else clause).
+        # See memory:project_beta_coop_residual_solo_bug for protocol.
+        if _DEBUG_FUSION and use_fusion and is_decode_only:
+            self._debug_paged_wo = self.wo_output.detach().clone()
+            self._debug_paged_rms = self.rmsnorm_output.detach().clone()
+            self._debug_paged_res = self.residual_output.detach().clone()

         # --- DEBUG: fusion diagnostic (CUTE_DEBUG_FUSION=1) ---
         # Compares kernel's impl.wo_output (Phase B GEMV) against a Python
         # reference computed from the kernel's own Phase A output (`result`)
         # and a one-time-dequantized W_O. Proves whether Phase B is faithful.
-        if _DEBUG_FUSION and use_fusion:
+        # Skip when paged was gated off (result is None).
+        if _DEBUG_FUSION and use_fusion and result is not None:
             self._debug_fusion_diff(
                 result=result,
                 num_actual_tokens=num_actual_tokens,
@@ -1173,16 +1213,18 @@ def forward(
         _resident_cap = getattr(self, "_resident_cap", 0)
         # Task 16: β-coop dispatch. β-coop requires the unified kernel
         # attached in attach_mlp_fusion (CUTE_PHASE_E_FUSION=1 at attach
-        # time). forced_path="coop" always routes here; "auto" routes here
-        # when the full grid fits the resident cap for a single cooperative
-        # launch (otherwise β-lite's two-kernel path handles it).
+        # time). 2026-04-26: cooperative-launch fitness is a HARD gate
+        # for both forced_path values — was previously bypassed when
+        # forced_path=="coop", causing CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE
+        # on multi-seq decode and silent gibberish (β-lite is MLP-only,
+        # provides no attention fallback). Must stay in sync with
+        # `_will_fire_beta_coop_pre` above.
         _coop_attached = getattr(self, "_phase_e_coop_kernel", None) is not None
-        _use_beta_coop = _phase_e_active and _coop_attached and (
-            _phase_e_env.forced_path == "coop"
-            or (
-                _phase_e_env.forced_path == "auto"
-                and _total_ctas <= _resident_cap
-            )
+        _use_beta_coop = (
+            _phase_e_active
+            and _coop_attached
+            and _total_ctas <= _resident_cap
+            and _phase_e_env.forced_path in ("coop", "auto")
         )
         _use_beta_lite = (
             _phase_e_active
@@ -1261,9 +1303,81 @@ def forward(
                     # Caller-supplied residual_output so self.residual_output
                     # reflects residual_post_attn (Phase-1 Phase-C output).
                     residual_output=self.residual_output[:nat],
+                    # C2: Qwen3.5 attn output gate — buffer was filled by
+                    # qwen3_5.py:267 from the q_proj's gate slice. Mirrors
+                    # the paged kernel's `gate_buf=` plumbing.
+                    gate_buf=self.gate_buf[:nat],
                 )
                 self._phase_e_consumed = True
                 self._phase_e_use_beta_coop = True
+                # 2026-04-26: ENV-GATED dump for off-line math verification.
+                # CUTE_DUMP_TENSORS=1 enables; bounded to first 3 decode
+                # steps × 16 full-attn layers so disk doesn't bloat. Files
+                # land in /tmp/nvllm-dumps/layer{N}_step{S}_{name}.pt.
+                # See ~/jupyterlab/beta_coop_kernel_dump_compare.ipynb.
+                if os.environ.get("CUTE_DUMP_TENSORS", "0") == "1":
+                    _dump_dir = "/tmp/nvllm-dumps"
+                    os.makedirs(_dump_dir, exist_ok=True)
+                    _step_counter = getattr(self, "_dump_step_counter", 0)
+                    if _step_counter < 3 * 16:
+                        _layer_segs = getattr(
+                            layer, "layer_name", "<layer>").split(".")
+                        _layer_digits = [
+                            p for p in _layer_segs if p.isdigit()]
+                        _layer_idx = int(_layer_digits[0]) \
+                            if _layer_digits else -1
+                        _base = (f"{_dump_dir}/layer{_layer_idx}_"
+                                 f"step{_step_counter // 16}")
+                        torch.save(
+                            self.residual_buf[:nat].detach().clone(),
+                            f"{_base}_residual_in.pt")
+                        torch.save(
+                            query[:nat].detach().clone(),
+                            f"{_base}_query.pt")
+                        torch.save(
+                            self.gate_buf[:nat].detach().clone(),
+                            f"{_base}_gate.pt")
+                        torch.save(
+                            self.residual_output[:nat].detach().clone(),
+                            f"{_base}_residual_out.pt")
+                        torch.save(
+                            self.rmsnorm_output[:nat].detach().clone(),
+                            f"{_base}_rmsnorm_out.pt")
+                        self._dump_step_counter = _step_counter + 1
+                # --- BETA_DIFF harness: diff β-coop's overwrite vs paged.
+                # See memory:project_beta_coop_residual_solo_bug for protocol.
+                if (_DEBUG_FUSION and is_decode_only
+                        and getattr(self, "_debug_paged_res", None) is not None):
+                    nat_dbg = num_actual_tokens
+                    wo_diff = (
+                        self.wo_output[:nat_dbg]
+                        - self._debug_paged_wo[:nat_dbg]
+                    ).abs()
+                    rms_diff = (
+                        self.rmsnorm_output[:nat_dbg].float()
+                        - self._debug_paged_rms[:nat_dbg].float()
+                    ).abs()
+                    res_diff = (
+                        self.residual_output[:nat_dbg].float()
+                        - self._debug_paged_res[:nat_dbg].float()
+                    ).abs()
+                    # Also dump the raw β-coop residual_output[0, :8] and
+                    # the corresponding paged value, so we can eyeball
+                    # whether sentinel landed.
+                    res_b = self.residual_output[0, :8].float().tolist()
+                    res_p = self._debug_paged_res[0, :8].float().tolist()
+                    logger.info(
+                        "[BETA_DIFF] layer=%s nat=%d "
+                        "wo:max=%.4e mean=%.4e | "
+                        "rms:max=%.4e mean=%.4e | "
+                        "res:max=%.4e mean=%.4e | "
+                        "res_beta[0,:8]=%s | res_paged[0,:8]=%s",
+                        getattr(layer, "layer_name", "<layer>"), nat_dbg,
+                        wo_diff.max().item(), wo_diff.mean().item(),
+                        rms_diff.max().item(), rms_diff.mean().item(),
+                        res_diff.max().item(), res_diff.mean().item(),
+                        res_b, res_p,
+                    )
             except Exception as e:  # noqa: BLE001 — fail-closed, fall through to β-lite
                 logger.warning(
                     "CuTe Phase E β-coop launch failed (falling back to "
@@ -1394,7 +1508,14 @@ def forward(
         # self._mlp_fusion_active = False
         # --- END PHASE D2 DISABLED ---

-        output[:num_actual_tokens].copy_(result)
+        # C2: when paged_attention_forward was gated off (β-coop fired
+        # alone), `result` is None. β-coop wrote its outputs into
+        # self.rmsnorm_output / self.mlp_output / self.residual_output
+        # which the consume branch reads directly — `output` is the
+        # framework's unified attn-output buffer, not consumed in the
+        # fusion path. Skip the copy_ in that case.
+        if result is not None:
+            output[:num_actual_tokens].copy_(result)
         return output

     def _debug_fusion_diff(
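For intuition on the hard gate above: the β-coop grid is fixed at 64 CTAs per decode sequence, so fitness reduces to a comparison against the backend's resident cap. A toy illustration follows; the 148-CTA cap is an assumed figure for illustration, not the real _resident_cap.

# Illustration only: resident_cap is a made-up number standing in for the
# backend's _resident_cap, which the real predicate reads via getattr(self, ...).
resident_cap = 148
for num_seqs in (1, 2, 3):
    grid_ctas = 64 * num_seqs
    verdict = "coop fits" if grid_ctas <= resident_cap else "hard-gated -> stays on decode path"
    print(f"num_seqs={num_seqs} grid={grid_ctas} {verdict}")
# With this cap: 1 and 2 sequences fit; 3 sequences (192 CTAs) trip the gate,
# matching the nat=3 failure described in the commit message.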

vllm/v1/attention/backends/cute_paged/_mlp_op.py

Lines changed: 67 additions & 0 deletions
@@ -242,3 +242,70 @@ def _cute_phase_e_dispatch_fake(
 # ε epilogue (Phase 4) was deleted in the same commit.
 #
 # See docs/research/uber_kernel_migration/q4_brainstorm_layer_LN_2026-04-25.md.
+
+
+# --- 2026-04-26: cute_residual_mirror -----------------------------------------
+# Opaque op for the residual mirror copy that qwen3_5.py's decoder forward
+# does at layer entry: `impl.residual_buf[:nat].copy_(residual[:nat])`.
+#
+# Why an op: the prior plain-Python `.copy_()` was inside an `if fusion_could_run:
+# try: ... attn_md = get_forward_context().attn_metadata[layer_name] ...`
+# block. Under @support_torch_compile (model.forward), dynamo traced the
+# get_forward_context lookup (None at trace time) → TypeError → except
+# pass. The captured graph then dropped the `.copy_` because (a) the
+# inferred trace path always took the except branch and (b) `impl.residual_buf`
+# is mutated state torch.compile doesn't track as a graph output. Result at
+# runtime: residual_buf stayed at the CUDA-graph-allocator-zeroed value;
+# β-coop read zeros; gibberish. (Verified 2026-04-26 via /tmp/nvllm-dumps —
+# residual_in absmax=0.0000 across all 16 full-attn layers.)
+#
+# Wrapping the copy in an opaque custom op makes it a black-box side-effect
+# from torch.compile's perspective — it's preserved across graph capture and
+# always runs at runtime.
+
+_RES_MIRROR_DIAG_SEEN: set[int] = set()
+
+
+def _cute_residual_mirror_impl(
+    residual_buf: torch.Tensor,
+    residual: torch.Tensor,
+) -> None:
+    """Copy `residual` into `residual_buf` (in-place mutation).
+
+    Direct buffer-passing replaces the prior registry-lookup design:
+    `mutates_args=["residual_buf"]` tells torch.compile the op has a
+    real side effect on a tracked tensor, so it isn't dead-eliminated.
+    """
+    nat = residual.shape[0]
+    if nat == 0:
+        return
+    nat = min(nat, residual_buf.shape[0])
+    # 2026-04-26 DIAG: one-shot per residual_buf identity. Logs whether
+    # the op fires at runtime + the input magnitude. Remove after ship.
+    _key = id(residual_buf)
+    if _key not in _RES_MIRROR_DIAG_SEEN:
+        _RES_MIRROR_DIAG_SEEN.add(_key)
+        logger.info(
+            "[RES_MIRROR_OP] nat=%d residual_absmax=%.4e "
+            "buf_shape=%s buf_pre_absmax=%.4e",
+            nat,
+            residual.float().abs().max().item(),
+            tuple(residual_buf.shape),
+            residual_buf.float().abs().max().item(),
+        )
+    residual_buf[:nat].copy_(residual[:nat])
+
+
+def _cute_residual_mirror_fake(
+    residual_buf: torch.Tensor,
+    residual: torch.Tensor,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="cute_residual_mirror",
+    op_func=_cute_residual_mirror_impl,
+    mutates_args=["residual_buf"],
+    fake_impl=_cute_residual_mirror_fake,
+)
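A quick eager-mode sanity check of the new op (a sketch, assuming this module has been imported on a CUDA host so the registration above has run; shapes here are arbitrary stand-ins):

import torch

buf = torch.zeros(8, 16, dtype=torch.bfloat16, device="cuda")   # stand-in for residual_buf
src = torch.randn(3, 16, dtype=torch.bfloat16, device="cuda")   # stand-in residual, nat=3
torch.ops.vllm.cute_residual_mirror(buf, src)
assert torch.equal(buf[:3], src)                  # first nat rows mirrored
assert buf[3:].abs().max().item() == 0.0          # remaining rows untouched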
