Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA #28358
Merged
justinchuby merged 1 commit into main on May 5, 2026
Conversation
The Memory-Efficient Attention (MEA) path in the CUDA Attention kernel crashes with a misaligned address when q_num_heads != kv_num_heads (GQA mode) and head_size != v_head_size. This happens because MEA's LaunchUngroup requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case, not the general GQA case.

Add the missing check: skip MEA for GQA when head_size != v_head_size, allowing the Unfused Attention fallback to handle it correctly. This fixes Gemma4 models with KV sharing where Q has head_dim=256 but shared K,V have head_dim=512. CPU EP handled this correctly; CUDA EP crashed at seq_len >= 4 when Flash Attention was not eligible.

Fixes #28357

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
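For orientation, a minimal sketch of the eligibility gate this change touches, with hypothetical names and only the head-size conditions shown (the real predicate in core/providers/cuda/llm/attention.cc checks several other conditions):

```cpp
// Hypothetical sketch, not the actual onnxruntime predicate.
bool CanUseMemoryEfficientAttention(bool has_memory_efficient_attention,
                                    const void* past_key,
                                    int num_heads, int kv_num_heads,
                                    int head_size, int v_head_size) {
  const bool is_gqa = num_heads != kv_num_heads;
  return has_memory_efficient_attention &&
         // LaunchConcatNewToPastKV (past-KV concat) assumes equal head sizes,
         // hence the pre-existing past_key guard.
         (past_key == nullptr || head_size == v_head_size) &&
         // New guard from this PR: LaunchUngroup (GQA head expansion) also
         // assumes a single head size, so asymmetric GQA falls back to the
         // unfused attention kernel instead.
         (!is_gqa || head_size == v_head_size);
}
```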
tianleiwu approved these changes on May 5, 2026

tianleiwu (Contributor) left a comment
Review Summary
Clean, well-targeted one-line fix for a CUDA crash (cudaErrorMisalignedAddress) when MEA is used with GQA and mismatched Q/KV head dimensions. The root cause analysis is accurate — LaunchUngroup uses float2* reinterpret casts and a single head_size, so it requires head_size == v_head_size. The fix correctly gates MEA eligibility, and the unfused fallback explicitly supports this configuration.
Positives:
- The new condition `(!is_gqa || head_size == v_head_size)` is logically independent from the existing `(past_key == nullptr || head_size == v_head_size)`: each guards a distinct internal function (`LaunchUngroup` vs `LaunchConcatNewToPastKV`), and both are needed.
- Updated comments clearly document the two separate failure modes.
- The defense-in-depth `ORT_ENFORCE` at line 722 inside `RunMemoryEfficientAttention` remains as a safety net.
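As a hedged illustration of that defense-in-depth layering (the real `RunMemoryEfficientAttention` has a different signature, its line-722 check may word the condition differently, and `ENFORCE_SKETCH` merely stands in for onnxruntime's `ORT_ENFORCE`):

```cpp
#include <stdexcept>

// Stand-in for onnxruntime's ORT_ENFORCE macro, used here only so the sketch
// compiles on its own.
#define ENFORCE_SKETCH(cond, msg) \
  do { if (!(cond)) throw std::runtime_error(msg); } while (0)

// Even though the dispatch predicate already routes asymmetric-head-size GQA
// away from MEA, the callee keeps its own check so a future caller cannot
// silently reach LaunchUngroup with head_size != v_head_size again.
void RunMemoryEfficientAttentionSketch(int num_heads, int kv_num_heads,
                                       int head_size, int v_head_size) {
  const bool is_gqa = num_heads != kv_num_heads;
  ENFORCE_SKETCH(!is_gqa || head_size == v_head_size,
                 "MEA with GQA requires head_size == v_head_size (LaunchUngroup)");
  // ... ungroup KV heads and call the CUTLASS FMHA kernel here ...
}
```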
Contributor
I will add tests.
This was referenced May 5, 2026
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request on May 5, 2026
…coverage

Addresses microsoft#28351 sub-items REG, HS4, 1c, 1e:

* HS4 (production): Add a (head_size % 4 == 0) clause to the MEA dispatch predicate at core/providers/cuda/llm/attention.cc as forward-looking defense-in-depth. The clause is REDUNDANT TODAY: has_memory_efficient_attention already enforces (qk_head_size & 7) == 0 (i.e. % 8) at contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h:71-72, which strictly implies % 4. We are not closing a current bug. The clause is kept as the strictest dtype-agnostic alignment floor that CUTLASS FMHA's BiasLoader actually requires (BiasLoader hardcodes a 128-bit / sizeof_bits-element alignment on Q/bias loads: 4 elements for fp32). Once microsoft#28365 lands and BiasLoader switches to kAlignmentA / DispatchIsAligned, MEA's own % 8 invariant will be loosened and this clause becomes load-bearing, preventing a correctness regression. The new comment block at the predicate site cites microsoft#28365 so the next maintainer can identify the right moment to delete it.
* REG (test): TestONNXAttentionGQAAsymmetricHeadSize pins the asymmetric v_head_size != head_size GQA path on fp16 and bf16 to guard against regression of the microsoft#28358 fix that removed the LaunchUngroup head_size == v_head_size ENFORCE.
* HS4 (test): TestONNXAttentionGQAHeadSizeMod4 sweeps head_size in {6, 10, 12, 16, 24}. Today head_sizes 6/10/12 are filtered upstream by MEA's % 8 gate and take the unfused fall-through path; this test pins that the fall-through stays numerically correct. 16/24 satisfy both % 8 and % 4 and exercise the MEA happy path. Once microsoft#28365 relaxes MEA's % 8 invariant, head_sizes 6/10 will start exercising the HS4 host-side gate directly.
* 1c (test): TestONNXAttentionGQAOutputQK pins the GQA + qk_matmul_output_mode combination (kQK raw scaled QK output), which previously had no test coverage. Threads an optional output_qk parameter through common.py's create_attention_graph_prompt and attention_prompt_func.
* 1e (test): TestONNXAttentionGQASoftcapFloat32 pins the fp32 + softcap + GQA combination (symmetric and asymmetric V head). fp32 GQA always falls through to the unfused path on CUDA; existing softcap tests are fp16/bf16, so the fp32 unfused softcap branch had no parity coverage.

All 10 new tests pass on H100 (sm_90a). Full test_gqa.py: 91 passed, 3 pre-existing flakes ('Output mismatch between two runs' determinism checks in unrelated decode-flash classes; not regressions from this change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
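As a quick, hedged arithmetic illustration of the "128-bit / sizeof_bits-element" alignment floor that commit discusses (and of why the fp32-only "4 elements" figure was later corrected for fp16/bf16), assuming alignment is expressed in elements of the tensor's dtype:

```cpp
#include <cstdio>

int main() {
  // Elements per 128-bit aligned access, per dtype.
  const struct { const char* name; int bits; } dtypes[] = {
      {"fp32", 32}, {"fp16", 16}, {"bf16", 16}};
  for (const auto& d : dtypes) {
    std::printf("%s: 128 bits / %d bits = %d-element alignment floor\n",
                d.name, d.bits, 128 / d.bits);
  }
  // Note: (head_size & 7) == 0 (i.e. head_size % 8 == 0) strictly implies
  // head_size % 4 == 0, which is why the host-side % 4 clause was redundant.
  return 0;
}
```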
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request on May 5, 2026
… tighten output_qk validation

Three substantive items plus one docstring fix from round-2 reviewer feedback (bot + internal multi-reviewer consolidation).

* core/providers/cuda/llm/attention.cc: drop the host-side `head_size % 4 == 0` HS4 clause from the MEA-eligibility predicate and remove its multi-paragraph comment block. The clause is fully redundant today (`has_memory_efficient_attention` already requires `(qk_head_size & 7) == 0`, which strictly implies % 4) and the comment it carried made dtype-aware alignment claims that are wrong for fp16 / bf16 (BiasLoader needs an 8-element stride, not 4, for those dtypes). The dtype-aware alignment floor properly belongs in the BiasLoader fix (microsoft#28365), not as a vestigial redundant clause here. The predicate is now exactly the upstream/main shape for HS4 purposes.
* test/python/transformers/test_onnx_attention/test_gqa.py: delete TestONNXAttentionGQAHeadSizeMod4. With the HS4 clause gone there is no host-side gate left to validate; the parameterized sweep was exercising routing equivalence vs the unfused fall-through, which is already covered by the broader MEA / unfused tests.
* test/python/transformers/test_onnx_attention/common.py: tighten the output_qk parameter validation to `output_qk in {0, 1, 2, 3}` or `None`. The previous `is not None and >= 0` guard caught the C++ `kNone = -1` sentinel but still silently accepted invalid modes 4 / 5, which would build an ONNX graph with an out-of-range `qk_matmul_output_mode` attribute and let the unfused CUDA kernel populate the 4th output as raw QK regardless. Validation now raises immediately with a clear message at the helper boundary; the binding-allocation site downstream is simplified to `is not None` since validation has already happened. The NOTE block and both helper docstrings are updated to spell out the contract: `None` = disabled; `{0, 1, 2, 3}` = the corresponding QKMatMulOutputMode; anything else raises.
* test/python/transformers/test_onnx_attention/test_gqa.py: fix the TestONNXAttentionGQAAsymmetricHeadSize docstring. The pre-microsoft#28358 `head_size == v_head_size` ENFORCE in LaunchUngroup is an MEA-path enforcement (LaunchUngroup is the GQA head-expansion helper used by MEA before its FMHA kernel), not an unfused-path one. The docstring now correctly attributes it.

Verified locally on the PR-2 build (build_pr2/, sm_90a single-arch):
- All targeted PR-2 + ordering-guard tests pass (8/8): the existing OutputQK / SoftcapFloat32 / AsymmetricHeadSize / LargeHeadUnfused poison ordering guard, plus the 2 masked fp32 ordering tests added in the previous fix-up.
- test_onnx_attention/test_gqa.py: 89/91 pass on a quiet GPU. The 2 transient failures (FloatMaskDecode, MEAGQASoftcap softcap+mask decode) both pass cleanly when re-run in isolation; they are pre-existing run-to-run flakes (rtol=0/atol=0 strict-equality asserts) under shared-GPU pressure, not caused by this commit.
- The HS4 sweep class is gone (test count dropped from 96 to 91; the delta is exactly the 5 parameterized HS4 sweep cases, as expected).
- Manual negative test of the new validation:
  output_qk=None -> 3 outputs (disabled, OK)
  output_qk=2 -> 4 outputs (kQKSoftCap, OK)
  output_qk=-1 -> AssertionError (OK)
  output_qk=4 -> AssertionError (OK)
  output_qk=5 -> AssertionError (OK)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
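For reference, a hedged C++ sketch of the output_qk contract that commit describes; the actual validation lives in the Python test helper (common.py), and the function name here is hypothetical:

```cpp
#include <optional>
#include <stdexcept>

// Contract sketch:
//   std::nullopt  -> qk output disabled (matches the C++ kNone = -1 sentinel)
//   0..3          -> the corresponding QKMatMulOutputMode value
//   anything else -> rejected immediately at the helper boundary
int ValidateOutputQkMode(std::optional<int> output_qk) {
  if (!output_qk.has_value()) {
    return -1;  // no 4th output is requested
  }
  if (*output_qk < 0 || *output_qk > 3) {
    throw std::invalid_argument(
        "output_qk must be None or one of {0, 1, 2, 3} (QKMatMulOutputMode)");
  }
  return *output_qk;
}
```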
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request on May 6, 2026
…SKILL.md cleanup)

User requested folding PR microsoft#28371 entirely into this PR to eliminate the merge-ordering hazard between the two PRs and present a coherent fix+tests+docs reviewer story. Following the earlier d613966 migration of the output_qk plumbing and masked fp32 softcap ordering tests, this commit migrates the remaining microsoft#28371 content:

test_gqa.py:
- TestONNXAttentionGQAAsymmetricHeadSize (REG, 2 tests): guards the silent-broken-output regression on GQA + asymmetric Q/V head sizes that was fixed by PR microsoft#28358 (microsoft#28357). Pins the post-fix unfused-path behaviour on fp16 + bf16.
- TestONNXAttentionGQASoftcapFloat32 (1e baseline, 2 tests): pins fp32 + softcap + GQA on the unfused path (MEA excludes is_gqa && fp32). Sibling to TestONNXAttentionGQASoftcapFloat32MaskOrdering already in this PR.

SKILL.md (cuda-attention-kernel-patterns):
- MEA eligibility paragraph: clarify that head_size % 8 is enforced by has_memory_efficient_attention, and that head_size == v_head_size is required for GQA (LaunchUngroup) in addition to decode (LaunchConcatNewToPastKV). Reflects the post-microsoft#28358 host-side gate cleanup.

After this commit, PR microsoft#28371 is fully superseded; the lead will close it.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Summary
Problem
The Memory-Efficient Attention (MEA) path crashes with `cudaErrorMisalignedAddress` when:
- GQA mode (`q_num_heads != kv_num_heads`)
- `head_size != v_head_size` (e.g., Q.head_dim=256, K.head_dim=512)
- `seq_len >= 4` (Flash Attention not eligible due to the attention mask)

This is because MEA's `LaunchUngroup` requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case.

Fix
Skip MEA for GQA when head sizes differ. The Unfused Attention fallback handles this correctly.
Affected Models
Gemma 4 was not affected; the original report stemmed from a previously incorrect graph. The fix is still worth having, as it improves robustness.
Gemma4 (google/gemma-4-e2b-it) with KV sharing: `head_size = 256`, `v_head_size = 512`

Testing
Minimal repro (from #28357):
Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention):
Fixes #28357