46 | 46 | # The op body reads from this dict at runtime (not at trace time). |
47 | 47 | _CUTE_MLP_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {} |
48 | 48 |
49 | | -# 2026-04-26 (B-fix): attn-consume registry, populated by |
50 | | -# `CutePagedAttentionImpl.attach_fusion`. Same impl object as |
51 | | -# _CUTE_MLP_REGISTRY but keyed by ATTENTION layer name (e.g. |
52 | | -# `language_model.model.layers.3.self_attn.attn`), not the MLP key |
53 | | -# used by cute_phase_e_dispatch. Allows cute_attn_consume and |
54 | | -# cute_post_attn_ln_dispatch to look up the impl and read its |
55 | | -# Python-side flags at runtime — avoids the .item() host-device sync |
56 | | -# on a 0-dim tensor signal (which raises cudaErrorStreamCaptureInvalidated |
57 | | -# under CUDA graph capture, verified 2026-04-26). |
58 | | -_CUTE_ATTN_REGISTRY: dict[str, "CutePagedAttentionImpl"] = {} |
59 | | - |
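
For context, the registry-dispatch pattern both dicts rely on reduces to the sketch below. It is illustrative only: `Impl`, `_REGISTRY`, and `my_consume` are hypothetical stand-ins, not the real vLLM objects. Because the op body is opaque to dynamo, the dict lookup and the Python bool are re-evaluated on every call, unlike a plain `if` in model code (specialized at trace time) or a 0-dim tensor signal (which needs a `.item()` sync).

    import torch

    class Impl:  # stand-in for CutePagedAttentionImpl
        def __init__(self) -> None:
            self._fusion_bound = False

    _REGISTRY: dict[str, Impl] = {}  # module-level, visible to the op body

    def my_consume(out: torch.Tensor, src: torch.Tensor, layer_name: str) -> None:
        # Opaque-op body: dynamo never traces through here, so this branch
        # reads the current Python flag at call time, with no device sync.
        impl = _REGISTRY.get(layer_name)
        if impl is None or not impl._fusion_bound:
            return  # non-fusion: leave `out` as the Python pipeline wrote it
        n = min(out.shape[0], src.shape[0])
        out[:n].copy_(src[:n])  # fusion: consume the side buffer

    # attach-time, once per fusion-bound layer:
    impl = Impl()
    impl._fusion_bound = True
    _REGISTRY["language_model.model.layers.3.self_attn.attn"] = impl
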
60 | 49 |
61 | 50 | def _cute_mlp_forward_impl( |
62 | 51 | x: torch.Tensor, |
@@ -320,169 +309,3 @@ def _cute_residual_mirror_fake( |
320 | 309 | mutates_args=["residual_buf"], |
321 | 310 | fake_impl=_cute_residual_mirror_fake, |
322 | 311 | ) |
323 | | - |
324 | | - |
325 | | -# --- 2026-04-26: cute_attn_consume + cute_post_attn_ln_dispatch ---------------- |
326 | | -# B-fix: replace the dead-eliminated Python `if _fusion_active` consume branch |
327 | | -# at qwen3_5.py:466-476 and the dead-eliminated `if not _fusion_active` |
328 | | -# post_attention_layernorm gate at qwen3_5.py:490-496. |
329 | | -# |
330 | | -# WHY needed: the captured FX graph (verified 2026-04-26 via |
331 | | -# /root/.cache/vllm/torch_compile_cache/<hash>/rank_0_0/backbone/computation_graph.py) |
332 | | -# specialized BOTH gates at trace time on `_fusion_active = False` (the impl's |
333 | | -# __init__ default) — dynamo can't see the runtime mutation that happens inside |
334 | | -# the unified_attention opaque op. Result: the consume copy was DCE'd, the |
335 | | -# legacy Python o_proj + post_attn_LN ALWAYS ran, β-coop's rmsnorm_output / |
336 | | -# residual_output were never read by the captured graph. In dual-fire this |
337 | | -# happened to produce coherent output because paged populated `output` with |
338 | | -# Phase A and the Python pipeline applied o_proj + post_attn_LN over it. In |
339 | | -# solo (paged gated off, β-coop only), `output` stayed uninitialised and |
340 | | -# Python applied o_proj over junk → gibberish. |
341 | | -# |
342 | | -# Fix: route the consume / postln decision through a Python-side flag on
343 | | -# the impl (`_fusion_bound`, set once at attach_fusion), looked up via
344 | | -# `_CUTE_ATTN_REGISTRY[layer_name]` inside these opaque ops. The lookup is
345 | | -# invisible to dynamo's trace-time specialization, reads the flag's current
346 | | -# value at runtime, and needs no .item() host-device sync, so it is safe
347 | | -# under CUDA graph capture. Both ops always run, dispatching on the flag:
348 | | -#   flag False : non-fusion mode (β-coop not bound). consume no-ops;
349 | | -#                postln applies the fused-residual RMSNorm in-place over
350 | | -#                the Python o_proj's wo_out.
351 | | -#   flag True  : fusion mode (β-coop fired). consume copies β-coop's
352 | | -#                rmsnorm_output → self_attention_output and residual_output
353 | | -#                → residual; postln no-ops (β-coop's Phase 1C already produced LN(post_input_LN_residual + wo_out)·γ).
354 | | -# |
355 | | -# residual_buf and gate_buf are passed to consume as PHANTOM inputs (not |
356 | | -# read inside the body); their sole purpose is to give the two upstream
357 | | -# cute_residual_mirror calls (one per buffer) observable downstream readers
358 | | -# in the captured graph, which prevents dynamo's DCE from dropping them |
359 | | -# (verified empirically that mutates_args alone is NOT sufficient against |
360 | | -# DCE — the ops were dead-eliminated despite mutates_args=["residual_buf"] |
361 | | -# until a downstream reader was added). |
362 | | - |
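
The phantom-reader trick is general. Here is a minimal standalone sketch using stock `torch.library.custom_op` (available in recent PyTorch) rather than vLLM's `direct_register_custom_op` wrapper; the `demo` namespace and op names are hypothetical:

    import torch

    @torch.library.custom_op("demo::mirror", mutates_args={"buf"})
    def mirror(buf: torch.Tensor, src: torch.Tensor) -> None:
        # Side effect the tracer cannot reason about.
        buf[: src.shape[0]].copy_(src)

    @mirror.register_fake
    def _(buf: torch.Tensor, src: torch.Tensor) -> None:
        return

    @torch.library.custom_op("demo::consume", mutates_args={"out"})
    def consume(out: torch.Tensor, buf: torch.Tensor) -> None:
        # `buf` is a phantom input: never read here. Listing it as an input
        # places a reader of `buf` downstream of demo::mirror in the captured
        # graph, so DCE keeps the mirror alive even when nothing else uses it.
        out.add_(1.0)

    @consume.register_fake
    def _(out: torch.Tensor, buf: torch.Tensor) -> None:
        return
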
363 | | - |
364 | | -def _cute_attn_consume_impl( |
365 | | - self_attention_output: torch.Tensor, # mutated [num_tokens, hidden_dim] BF16 |
366 | | - residual: torch.Tensor, # mutated [num_tokens, hidden_dim] BF16 |
367 | | - rmsnorm_output: torch.Tensor, # impl.rmsnorm_output [max_num_seqs, hidden_dim] BF16 |
368 | | - residual_output: torch.Tensor, # impl.residual_output [max_num_seqs, hidden_dim] BF16 |
369 | | - residual_buf: torch.Tensor, # phantom for cute_residual_mirror dep |
370 | | - gate_buf: torch.Tensor, # phantom for gate-mirror dep |
371 | | - layer_name: str, # registry key into _CUTE_ATTN_REGISTRY |
372 | | -) -> None: |
373 | | -    """If this layer is fusion-bound: copy β-coop's outputs into model-side tensors.
374 | | -
375 | | -    Gates on `impl._fusion_bound` (Python attr), read at runtime via
376 | | -    `_CUTE_ATTN_REGISTRY[layer_name]`: no .item() call, no CUDA sync,
377 | | -    safe under CUDA graph capture. Set once at attach_fusion and stable
378 | | -    across warmup and runtime, so True ⇔ this layer is fusion-bound and
379 | | -    β-coop fills rmsnorm_output and residual_output on every forward call.
380 | | - """ |
381 | | - impl = _CUTE_ATTN_REGISTRY.get(layer_name) |
382 | | - # 2026-04-26 (B-fix v2): gate on `_fusion_bound` (set once at |
383 | | - # attach_fusion, stable across warmup + runtime) rather than |
384 | | - # `_phase_e_use_beta_coop` (set per-step inside impl.forward — not |
385 | | - # consistently True at warmup capture time, so the captured segment |
386 | | - # would skip the consume kernels and replay would never fill |
387 | | - # self_attention_output from β-coop's outputs). With _fusion_bound: |
388 | | - # capture always sees True for fusion-bound full-attn layers, |
389 | | - # consume kernels always captured. Cost: if β-coop ever fails to |
390 | | - # fire at runtime (e.g. predicate fails), consume reads stale |
391 | | - # impl.rmsnorm_output. Mitigated by the predicate hard-gate landed |
392 | | - # in the prior commit which prevents silent β-coop fallthrough on |
393 | | -    # cooperative-launch-too-large. (See the capture-time sketch after this function.)
394 | | - if impl is None or not getattr(impl, "_fusion_bound", False): |
395 | | - # Non-fusion / non-bound: leave self_attention_output as-is (Python |
396 | | - # o_proj already wrote it) and residual untouched. |
397 | | - return |
398 | | -    # Fusion mode: β-coop's Phase 1C produced these. Clamp defensively to
399 | | -    # the buffer capacity (matches the original Python consume branch).
400 | | - nat = min(self_attention_output.shape[0], rmsnorm_output.shape[0]) |
401 | | - self_attention_output[:nat].copy_(rmsnorm_output[:nat]) |
402 | | - if nat < self_attention_output.shape[0]: |
403 | | - # Match the prior `if nat < num_tokens: self_attention_output[nat:].zero_()` |
404 | | - # — keeps unused rows deterministic across decode steps. |
405 | | - self_attention_output[nat:].zero_() |
406 | | - residual[:nat].copy_(residual_output[:nat]) |
407 | | - |
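
The capture-time specialization the B-fix v2 comment describes can be reproduced in isolation. A toy sketch (hypothetical, not vLLM code) showing that a Python-level branch is frozen into the captured graph, so later flag flips are invisible to replay:

    import torch

    flag = False  # stand-in for the per-step _phase_e_use_beta_coop

    def step(x: torch.Tensor) -> torch.Tensor:
        # Python branch: resolved once, while this runs under capture.
        return x + 1 if flag else x.clone()

    x = torch.zeros(4, device="cuda")

    # Warm up on a side stream (standard CUDA-graphs preamble).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        step(x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y = step(x)  # captures the flag == False path only

    flag = True
    g.replay()  # replays the recorded clone; y stays all-zero
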
408 | | - |
409 | | -def _cute_attn_consume_fake( |
410 | | - self_attention_output: torch.Tensor, |
411 | | - residual: torch.Tensor, |
412 | | - rmsnorm_output: torch.Tensor, |
413 | | - residual_output: torch.Tensor, |
414 | | - residual_buf: torch.Tensor, |
415 | | - gate_buf: torch.Tensor, |
416 | | - layer_name: str, |
417 | | -) -> None: |
418 | | - return |
419 | | - |
420 | | - |
421 | | -direct_register_custom_op( |
422 | | - op_name="cute_attn_consume", |
423 | | - op_func=_cute_attn_consume_impl, |
424 | | - # Both self_attention_output and residual are mutated when fusion fires; |
425 | | - # the phantom inputs are read-only. |
426 | | - mutates_args=["self_attention_output", "residual"], |
427 | | - fake_impl=_cute_attn_consume_fake, |
428 | | -) |
429 | | - |
430 | | - |
431 | | -def _cute_post_attn_ln_dispatch_impl( |
432 | | - hidden_states: torch.Tensor, # mutated [num_tokens, hidden_dim] BF16 |
433 | | - residual: torch.Tensor, # mutated [num_tokens, hidden_dim] BF16 |
434 | | - weight: torch.Tensor, # post_attention_layernorm.weight [hidden_dim] BF16 |
435 | | - rmsnorm_eps: float, |
436 | | - layer_name: str, # registry key into _CUTE_ATTN_REGISTRY |
437 | | -) -> None: |
438 | | - """If β-coop did NOT fire: apply fused-residual post_attention_layernorm. |
439 | | -
440 | | - Mirrors `_forward_static_with_residual` in vllm/nvllm/layers/layernorm.py: |
441 | | - combined = hidden_states + residual |
442 | | - residual = combined |
443 | | - x = combined.float() |
444 | | - var = x.pow(2).mean(dim=-1, keepdim=True) |
445 | | - x = x * torch.rsqrt(var + eps) |
446 | | - x = x * (1.0 + weight.float()) |
447 | | - hidden_states = x.to(combined.dtype) |
448 | | -
449 | | - When β-coop fired, its Phase 1C already produced this exact output into |
450 | | - hidden_states via cute_attn_consume above, and residual already holds |
451 | | - residual_post_attn — skip to avoid double-LN. |
452 | | -
453 | | -    Gates on `impl._fusion_bound` (Python attr): no .item() needed,
454 | | -    CUDA-graph-safe. See the cute_attn_consume docstring for the gate semantics.
455 | | - """ |
456 | | - impl = _CUTE_ATTN_REGISTRY.get(layer_name) |
457 | | - # See cute_attn_consume docstring above for why we gate on _fusion_bound |
458 | | - # rather than _phase_e_use_beta_coop. Symmetric: when consume fires, |
459 | | - # post_attn_LN must skip; when consume no-ops, post_attn_LN must apply. |
460 | | - if impl is not None and getattr(impl, "_fusion_bound", False): |
461 | | - # Fusion mode: β-coop already did post_attn_LN. Skip. |
462 | | - return |
463 | | - # Non-fusion mode: replicate _forward_static_with_residual in-place. |
464 | | - combined = hidden_states + residual |
465 | | - residual.copy_(combined) |
466 | | - x = combined.float() |
467 | | - var = x.pow(2).mean(dim=-1, keepdim=True) |
468 | | - x = x * torch.rsqrt(var + rmsnorm_eps) |
469 | | - x = x * (1.0 + weight.float()) |
470 | | - hidden_states.copy_(x.to(combined.dtype)) |
471 | | - |
472 | | - |
473 | | -def _cute_post_attn_ln_dispatch_fake( |
474 | | - hidden_states: torch.Tensor, |
475 | | - residual: torch.Tensor, |
476 | | - weight: torch.Tensor, |
477 | | - rmsnorm_eps: float, |
478 | | - layer_name: str, |
479 | | -) -> None: |
480 | | - return |
481 | | - |
482 | | - |
483 | | -direct_register_custom_op( |
484 | | - op_name="cute_post_attn_ln_dispatch", |
485 | | - op_func=_cute_post_attn_ln_dispatch_impl, |
486 | | - mutates_args=["hidden_states", "residual"], |
487 | | - fake_impl=_cute_post_attn_ln_dispatch_fake, |
488 | | -) |
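
Taken together, the call-site shape would be roughly the following. This is a hypothetical sketch, not the actual qwen3_5.py change: `post_attention`, `self.layer_name`, and the `torch.ops.vllm` namespace are assumptions based on vLLM's default custom-op registration and RMSNorm attribute names.

    import torch

    def post_attention(self, attn_out: torch.Tensor, residual: torch.Tensor,
                       residual_buf: torch.Tensor, gate_buf: torch.Tensor):
        # Hypothetical decoder-layer tail: both opaque ops run unconditionally
        # in place of the two Python `if _fusion_active` gates they replace.
        attn_out = self.o_proj(attn_out)  # legacy Python o_proj always runs
        impl = _CUTE_ATTN_REGISTRY[self.layer_name]
        torch.ops.vllm.cute_attn_consume(
            attn_out, residual,
            impl.rmsnorm_output, impl.residual_output,
            residual_buf, gate_buf,  # phantom DCE anchors, never read
            self.layer_name,
        )
        torch.ops.vllm.cute_post_attn_ln_dispatch(
            attn_out, residual,
            self.post_attention_layernorm.weight,
            self.post_attention_layernorm.variance_epsilon,
            self.layer_name,
        )
        # Either way, attn_out now holds the post-attention-LN output and
        # residual holds the post-attention residual.
        return attn_out, residual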