
Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA #28358

Merged
justinchuby merged 1 commit into main from fix-attention-head-size-mismatch on May 5, 2026

Conversation

@justinchuby (Contributor) commented May 5, 2026

Summary

Problem

The Memory-Efficient Attention (MEA) path crashes with cudaErrorMisalignedAddress when:

  • GQA mode (q_num_heads != kv_num_heads)
  • head_size != v_head_size (e.g., Q.head_dim=256, K.head_dim=512)
  • seq_len >= 4 (Flash Attention not eligible due to attention mask)

This is because MEA's LaunchUngroup requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case.

Fix

Skip MEA for GQA when head sizes differ. The Unfused Attention fallback handles this correctly.
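
For context, a hedged sketch of the eligibility check after the change; the identifiers (is_gqa, past_key, head_size, v_head_size) follow the review discussion below rather than the exact attention.cc source:

// Sketch only: assumed shape of the MEA dispatch predicate, post-fix.
bool use_memory_efficient_attention =
    has_memory_efficient_attention &&
    // Existing guard: LaunchConcatNewToPastKV cannot concatenate past KV
    // when the QK and V head sizes differ.
    (past_key == nullptr || head_size == v_head_size) &&
    // New guard (this PR): LaunchUngroup also assumes a single head size,
    // so GQA with mismatched sizes falls back to Unfused Attention.
    (!is_gqa || head_size == v_head_size);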

Affected Models

Gemma4 itself was not affected: the graph that triggered the crash turned out to be incorrectly exported. The fix is still worth having, since it makes the dispatch more robust for any graph with mismatched head sizes.

Gemma4 (google/gemma-4-e2b-it) with KV sharing:

  • Layers 15-34 borrow K,V from source layers
  • Q projection: 1536 → 2048 (8 heads × 256)
  • K/V from source: [batch, 1, seq, 512]
  • head_size = 256, v_head_size = 512

Testing

Minimal repro (from #28357):

# Attention(Q=[1,S,2048], K=[1,S,512], V=[1,S,512], q_num_heads=8, kv_num_heads=1)
# Before fix: seq=4+ crashes with misaligned address
# After fix: all seq lengths work

Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention):

  • Prefill seq=32: ✅
  • Decode seq=1: ✅

Fixes #28357

Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA

The Memory-Efficient Attention (MEA) path in the CUDA Attention kernel
crashes with misaligned address when q_num_heads != kv_num_heads (GQA
mode) and head_size != v_head_size. This happens because MEA's
LaunchUngroup requires equal head sizes, but the dispatch logic only
checked this constraint for the past_key case, not the general GQA case.

Add the missing check: skip MEA for GQA when head_size != v_head_size,
allowing the Unfused Attention fallback to handle it correctly.

This fixes Gemma4 models with KV sharing where Q has head_dim=256 but
shared K,V have head_dim=512. CPU EP handled this correctly; CUDA EP
crashed at seq_len >= 4 when Flash Attention was not eligible.

Fixes #28357

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
@tianleiwu (Contributor) left a comment

Review Summary

Clean, well-targeted one-line fix for a CUDA crash (cudaErrorMisalignedAddress) when MEA is used with GQA and mismatched Q/KV head dimensions. The root cause analysis is accurate — LaunchUngroup uses float2* reinterpret casts and a single head_size, so it requires head_size == v_head_size. The fix correctly gates MEA eligibility, and the unfused fallback explicitly supports this configuration.

Positives:

  • The new condition (!is_gqa || head_size == v_head_size) is logically independent from the existing (past_key == nullptr || head_size == v_head_size) — each guards a distinct internal function (LaunchUngroup vs LaunchConcatNewToPastKV), and both are needed.
  • Updated comments clearly document the two separate failure modes.
  • Defense-in-depth ORT_ENFORCE at line 722 inside RunMemoryEfficientAttention remains as a safety net.
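
To make the root cause concrete, here is a hypothetical CPU analog of the ungroup step (not the actual CUDA LaunchUngroup), assuming a row-major [num_heads, seq_len, head_size] layout. It shows why a single head_size parameter cannot describe both K and V when their widths differ:

#include <algorithm>

// Hypothetical sketch: replicate each of the kv_num_heads KV heads so every
// query head sees its own copy (q_num_heads / kv_num_heads copies per head).
void UngroupSketch(const float* kv_in, float* kv_out,
                   int kv_num_heads, int q_num_heads,
                   int seq_len, int head_size) {
  const int factor = q_num_heads / kv_num_heads;  // expansion per KV head
  for (int h = 0; h < q_num_heads; ++h) {
    const int src_h = h / factor;  // KV head that query head h maps onto
    for (int s = 0; s < seq_len; ++s) {
      const float* src = kv_in + (src_h * seq_len + s) * static_cast<size_t>(head_size);
      float* dst = kv_out + (h * seq_len + s) * static_cast<size_t>(head_size);
      // A single head_size governs every copied row; if V rows are really
      // v_head_size wide, these offsets are wrong for V, which is why the
      // dispatch must exclude head_size != v_head_size.
      std::copy(src, src + head_size, dst);
    }
  }
}

The real kernel additionally reads through float2* reinterpret casts, so bad offsets do not just copy the wrong data; they can violate the 8-byte alignment those loads require, producing the cudaErrorMisalignedAddress above.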

Comment thread on onnxruntime/core/providers/cuda/llm/attention.cc
@justinchuby justinchuby merged commit 1f25783 into main May 5, 2026
91 of 93 checks passed
@justinchuby justinchuby deleted the fix-attention-head-size-mismatch branch May 5, 2026 14:13
@titaiwangms (Contributor)

I will add tests.

titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 5, 2026
…coverage

Addresses microsoft#28351 sub-items REG, HS4, 1c, 1e:

* HS4 (production): Add (head_size % 4 == 0) clause to the MEA dispatch
  predicate at core/providers/cuda/llm/attention.cc as forward-looking
  defense-in-depth. The clause is REDUNDANT TODAY: has_memory_efficient_
  attention already enforces (qk_head_size & 7) == 0 (i.e. % 8) at
  contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h:71-72,
  which strictly implies % 4. We are not closing a current bug. The clause
  is kept as the strictest dtype-agnostic alignment floor that CUTLASS
  FMHA's BiasLoader actually requires (BiasLoader hardcodes a 128-bit /
  sizeof_bits-element alignment on Q/bias loads — 4 elements for fp32).
  Once microsoft#28365 lands and BiasLoader switches to
  kAlignmentA / DispatchIsAligned, MEA's own % 8 invariant will be loosened
  and this clause becomes load-bearing, preventing a correctness regression.
  The new comment block at the predicate site cites microsoft#28365 so the next
  maintainer can identify the right moment to delete it (a sketch of the
  resulting predicate shape follows this list).

* REG (test): TestONNXAttentionGQAAsymmetricHeadSize pins the asymmetric
  v_head_size != head_size GQA path on fp16 and bf16 to guard against
  regression of the microsoft#28358 fix that removed the LaunchUngroup head_size ==
  v_head_size ENFORCE.

* HS4 (test): TestONNXAttentionGQAHeadSizeMod4 sweeps head_size in
  {6, 10, 12, 16, 24}. Today head_sizes 6/10/12 are filtered upstream by
  MEA's % 8 gate and take the unfused fall-through path; this test pins
  that fall-through stays numerically correct. 16/24 satisfy both % 8 and
  % 4 and exercise the MEA happy path. Once microsoft#28365 relaxes MEA's % 8
  invariant, head_sizes 6/10 will start exercising the HS4 host-side gate
  directly.

* 1c (test): TestONNXAttentionGQAOutputQK pins the GQA + qk_matmul_output_mode
  combination (kQK raw scaled QK output) which previously had no test
  coverage. Threads an optional output_qk parameter through common.py's
  create_attention_graph_prompt and attention_prompt_func.

* 1e (test): TestONNXAttentionGQASoftcapFloat32 pins the fp32 + softcap +
  GQA combination (symmetric and asymmetric V head). fp32 GQA always falls
  through to the unfused path on CUDA; existing softcap tests are fp16/bf16,
  so the fp32 unfused softcap branch had no parity coverage.
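
For reference, a hedged sketch of the predicate shape the HS4 bullet above describes (identifier names assumed from the commit text and the microsoft#28358 fix, not verified source):

// Assumed post-commit shape of the MEA eligibility predicate.
bool use_memory_efficient_attention =
    has_memory_efficient_attention &&  // already enforces (qk_head_size & 7) == 0
    (head_size % 4 == 0) &&            // HS4 clause: redundant today, load-bearing
                                       //   only once microsoft#28365 relaxes the % 8 gate
    (past_key == nullptr || head_size == v_head_size) &&
    (!is_gqa || head_size == v_head_size);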

All 10 new tests pass on H100 (sm_90a). Full test_gqa.py: 91 passed,
3 pre-existing flakes ('Output mismatch between two runs' determinism
checks in unrelated decode-flash classes — not regressions from this
change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 5, 2026
… tighten output_qk validation

Three substantive items + one docstring fix from round-2 reviewer
feedback (bot + internal multi-reviewer consolidation).

* core/providers/cuda/llm/attention.cc — drop the host-side
  `head_size % 4 == 0` HS4 clause from the MEA-eligibility predicate
  and remove its multi-paragraph comment block. The clause is fully
  redundant today (`has_memory_efficient_attention` already requires
  `(qk_head_size & 7) == 0`, which strictly implies %4) and the comment
  it carried made dtype-aware alignment claims that are wrong for fp16
  / bf16 (BiasLoader needs an 8-element stride, not 4, for those
  dtypes). The dtype-aware alignment floor properly belongs in the
  BiasLoader fix (microsoft#28365), not as a vestigial
  redundant clause here. Predicate is now exactly the upstream/main
  shape for HS4 purposes.

* test/python/transformers/test_onnx_attention/test_gqa.py — delete
  TestONNXAttentionGQAHeadSizeMod4. With the HS4 clause gone there is
  no host-side gate left to validate; the parameterized sweep was
  exercising routing equivalence vs the unfused fall-through, which is
  already covered by the broader MEA / unfused tests.

* test/python/transformers/test_onnx_attention/common.py — tighten the
  output_qk parameter validation to `output_qk in {0, 1, 2, 3}` or
  `None`. The previous `is not None and >= 0` guard caught the C++
  `kNone = -1` sentinel but still silently accepted invalid modes
  4 / 5, which would build an ONNX graph with an out-of-range
  `qk_matmul_output_mode` attribute and let the unfused CUDA kernel
  populate the 4th output as raw-QK regardless. Validation now raises
  immediately with a clear message at the helper boundary; the
  binding-allocation site downstream is simplified to `is not None`
  since validation has already happened. NOTE block + both helper
  docstrings updated to spell out the contract: `None` = disabled;
  `{0,1,2,3}` = the corresponding QKMatMulOutputMode; anything else
  raises.

* test/python/transformers/test_onnx_attention/test_gqa.py — fix the
  TestONNXAttentionGQAAsymmetricHeadSize docstring. The pre-microsoft#28358
  `head_size == v_head_size` ENFORCE in LaunchUngroup is an MEA-path
  enforcement (LaunchUngroup is the GQA head-expansion helper used by
  MEA before its FMHA kernel), not an unfused-path one. Docstring now
  correctly attributes it.

Verified locally on the PR-2 build (build_pr2/, sm_90a single-arch):
  - All targeted PR-2 + ordering-guard tests pass (8/8): the existing
    OutputQK / SoftcapFloat32 / AsymmetricHeadSize / LargeHeadUnfused
    poison ordering guard, plus the 2 masked fp32 ordering tests added
    in the previous fix-up.
  - test_onnx_attention/test_gqa.py: 89/91 pass on a quiet GPU. The 2
    transient failures (FloatMaskDecode, MEAGQASoftcap softcap+mask
    decode) both pass cleanly when re-run in isolation; they are
    pre-existing run-to-run flakes (rtol=0/atol=0 strict-equality
    asserts) under shared-GPU pressure, not caused by this commit.
  - HS4 sweep class is gone (file count dropped from 96 to 91; the
    delta is exactly the 5 parameterized HS4 sweep cases, as
    expected).
  - Manual negative test of the new validation:
      output_qk=None         -> 3 outputs (disabled, OK)
      output_qk=2            -> 4 outputs (kQKSoftCap, OK)
      output_qk=-1           -> AssertionError (OK)
      output_qk=4            -> AssertionError (OK)
      output_qk=5            -> AssertionError (OK)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 6, 2026
…SKILL.md cleanup)

User requested folding PR microsoft#28371 entirely into this PR to eliminate the merge-ordering hazard between the two PRs and present a coherent fix+tests+docs reviewer story.

Following the earlier d613966 migration of the output_qk plumbing and masked fp32 softcap ordering tests, this commit migrates the remaining microsoft#28371 content:

test_gqa.py:
- TestONNXAttentionGQAAsymmetricHeadSize (REG, 2 tests) — guards the silent-broken-output regression on GQA + asymmetric Q/V head sizes that was fixed by PR microsoft#28358 (microsoft#28357). Pins the post-fix unfused-path behaviour on fp16 + bf16.
- TestONNXAttentionGQASoftcapFloat32 (1e baseline, 2 tests) — pins fp32 + softcap + GQA on the unfused path (MEA excludes is_gqa && fp32). Sibling to TestONNXAttentionGQASoftcapFloat32MaskOrdering already in this PR.

SKILL.md (cuda-attention-kernel-patterns):
- MEA eligibility paragraph: clarify that head_size%8 is enforced by has_memory_efficient_attention, and that head_size == v_head_size is required for GQA (LaunchUngroup) in addition to decode (LaunchConcatNewToPastKV). Reflects the post-microsoft#28358 host-side gate cleanup.

After this commit, PR microsoft#28371 is fully superseded; the lead will close it.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

CUDA Attention kernel crashes with mismatched Q/K head dimensions (head_size != v_head_size)
