
Add CUDA Attention spec-coverage tests (GQA asymmetric head-size, output_qk, fp32 softcap+mask ordering)#28371

Closed
titaiwangms wants to merge 5 commits into microsoft:main from titaiwangms:fix-attention-spec-gaps-and-gqa-tests

Conversation


titaiwangms (Contributor) commented May 5, 2026

Adds CUDA Attention spec-coverage tests around the ONNX Attention op (opset 23/24), closing the test-side gap surfaced by #28358 and filling sub-items REG, 1c, and 1e of the audit in #28351 / #28357. The change is test-only: new tests, a test-helper API tightening, and one SKILL.md edit. Zero production kernel changes.

What changed (3 files)

| File | Change |
| --- | --- |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | +3 test classes / 7 cases (see below) |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | output_qk parameter contract: None disables the 4th output; 0..3 enables a specific kQK / kPostSoftCap / kPostMaskBias / kPostSoftMax mode; anything else (-1, 4, 5, …) raises AssertionError. Replaces the prior truthy / is not None check that silently bound the 4th output for the kNone = -1 sentinel. (See the sketch below.) |
| .agents/skills/cuda-attention-kernel-patterns/SKILL.md | §1 cleanup: removed the stale HS4 reference (the host-side head_size % 4 gate was dropped during review; see "Dropped during review" below). |
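
A minimal sketch of the tightened output_qk contract, assuming a standalone validation helper (the name resolve_output_qk and the mode map are illustrative, not the actual internals of common.py):

```python
from typing import Optional

# Hypothetical stand-in for the validation described above; the real helper
# lives in test_onnx_attention/common.py and may differ in detail.
_QK_MODES = {0: "kQK", 1: "kPostSoftCap", 2: "kPostMaskBias", 3: "kPostSoftMax"}


def resolve_output_qk(output_qk: Optional[int]) -> bool:
    """Return True when the optional 4th output should be built and bound.

    None -> disabled (default, 3-tuple return).
    0..3 -> enabled; the value becomes the qk_matmul_output_mode attribute.
    Anything else (-1, 4, 5, ...) -> AssertionError at graph-build time,
    which also catches callers mirroring the C++ kNone = -1 sentinel.
    """
    if output_qk is None:
        return False
    assert output_qk in _QK_MODES, (
        f"output_qk must be None or one of {sorted(_QK_MODES)}, got {output_qk!r}"
    )
    return True
```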

Tests added

  • REG — TestONNXAttentionGQAAsymmetricHeadSize — exercises GQA with qk_head_size != v_head_size on the MEA path's LaunchUngroup helper. Closes the test gap that hid the regression fixed in #28358 (Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA). A numpy reference sketch follows this list.
    • test_gqa_asymmetric_v_head_size_prompt_fp16
    • test_gqa_asymmetric_v_head_size_prompt_bf16
  • 1c — TestONNXAttentionGQAOutputQK — exercises the GQA + 4th-output (output_qk) plumbing on the unfused CUDA path with qk_matmul_output_mode = kQK.
    • test_gqa_output_qk_raw_prompt_fp16
  • 1e — TestONNXAttentionGQASoftcapFloat32 — pins fp32 GQA + softcap behaviour, including the mask-ordering invariant (poison-V at masked positions; with the post-#7913 ordering the mask is applied after softcap, so the bounded output proves correct ordering).
    • test_gqa_softcap_fp32_symmetric
    • test_gqa_softcap_fp32_asymmetric_v_head
    • test_gqa_softcap_fp32_with_mask_ordering_symmetric
    • test_gqa_softcap_fp32_with_mask_ordering_asymmetric_v_head

7 new test cases.
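
For orientation, here is a self-contained numpy reference of grouped-query attention with an asymmetric V head size, roughly the behaviour the REG tests pin. This is an illustrative sketch under simplified assumptions (no causal mask, no softcap, no past KV), not the repo's attention_ref:

```python
import numpy as np

def gqa_reference(q, k, v, q_num_heads, kv_num_heads):
    """Grouped-query attention with qk_head_size != v_head_size.

    q: (batch, seq, q_num_heads * qk_head_size)
    k: (batch, seq, kv_num_heads * qk_head_size)
    v: (batch, seq, kv_num_heads * v_head_size)
    returns: (batch, seq, q_num_heads * v_head_size)
    """
    b, s, _ = q.shape
    qk_hs = q.shape[-1] // q_num_heads
    v_hs = v.shape[-1] // kv_num_heads
    group = q_num_heads // kv_num_heads

    q = q.reshape(b, s, q_num_heads, qk_hs).transpose(0, 2, 1, 3)
    k = k.reshape(b, s, kv_num_heads, qk_hs).transpose(0, 2, 1, 3)
    v = v.reshape(b, s, kv_num_heads, v_hs).transpose(0, 2, 1, 3)
    # Expand each KV head across its query group (the job LaunchUngroup does on the MEA path).
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)

    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(qk_hs)
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    out = probs @ v                              # (b, q_num_heads, s, v_head_size)
    return out.transpose(0, 2, 1, 3).reshape(b, s, q_num_heads * v_hs)

# e.g. q_num_heads=8, kv_num_heads=2, qk_head_size=64, v_head_size=32:
# out = gqa_reference(q, k, v, 8, 2)  # -> (batch, seq, 8 * 32)
```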

Dropped during review

  • HS4 (host-side head_size % 4 gate in core/providers/cuda/llm/attention.cc) — the original PR added a host-side floor before MEA dispatch. Two reviewer streams (multi-model internal + Copilot bot) converged on the same finding: today the gate is redundant with the existing upstream (qk_head_size & 7) == 0 invariant in cutlass_fmha/memory_efficient_attention.h, and the % 4 floor is dtype-incorrect for fp16/bf16. BiasLoader does a 16-byte-wide load per row, so fp32 needs head_size % 4 == 0 while fp16/bf16 need head_size % 8 == 0, generalized as head_size % (16 / sizeof(T)) == 0. The dtype-aware floor properly belongs in #28365 (Fix CUDA EP: opset 24 kernel registrations + CUTLASS alignment + MEA dispatch), which owns the alignment loosening in BiasLoader → kAlignmentA. Posted a heads-up there with the precise math: #28365 (review).
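
To make the alignment arithmetic concrete, the floor is just the element count of one 16-byte BiasLoader row load for the given dtype; the helper name below is hypothetical and not code from attention.cc:

```python
import numpy as np

def min_mea_head_size_alignment(dtype) -> int:
    """Elements in one 16-byte (128-bit) BiasLoader row load: 16 / sizeof(T)."""
    return 16 // np.dtype(dtype).itemsize

assert min_mea_head_size_alignment(np.float32) == 4  # fp32: head_size % 4 == 0
assert min_mea_head_size_alignment(np.float16) == 8  # fp16: head_size % 8 == 0
# bf16 also has a 2-byte element, so its floor is likewise 8; a bare % 4 gate under-constrains it.
```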

Deferred to follow-up

  • 1d — opset-24 short-attn_mask auto-padding. ONNX opset 24 allows attn_mask.shape[-1] < total_sequence_length (trailing positions auto-pad with -inf, i.e., a short attn_mask used as a right-padding indicator). Today both CPU and CUDA ORT_ENFORCE strict equality (see the existing TODO at onnxruntime/core/providers/cpu/llm/attention_helper.h:141-147). User-visible impact: spec-legal models with a short attn_mask are rejected with an "inconsistent total_sequence_length" error (a loud failure, not silent wrong output).
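
A minimal numpy sketch of the opset-24 semantics described above, assuming an additive float mask; the helper name and shapes are illustrative, not the planned ORT implementation:

```python
import numpy as np

def pad_attn_mask(attn_mask: np.ndarray, total_sequence_length: int) -> np.ndarray:
    """Right-pad a short additive attn_mask with -inf along the last axis.

    Opset 24 allows attn_mask.shape[-1] < total_sequence_length; the trailing
    (unspecified) key positions are treated as fully masked.
    """
    pad = total_sequence_length - attn_mask.shape[-1]
    if pad < 0:
        raise ValueError("attn_mask is longer than total_sequence_length")
    if pad == 0:
        return attn_mask
    pad_width = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad)]
    return np.pad(attn_mask, pad_width, constant_values=-np.inf)

# e.g. a (1, 1, 4, 6) float mask padded against total_sequence_length=8
# gains two -inf columns, i.e. the last two key positions are masked out.
```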

Cross-references

Process notes

  • 4 additive commits over the review cycle (b4d7dbf59b → cb8bd5902a → e8a8db723a → 99a48bc689 → 449fa8cfa5); no force-push, so review history is intact.
  • Verified: lintrunner clean at HEAD 449fa8cfa5.

…coverage

Addresses microsoft#28351 sub-items REG, HS4, 1c, 1e:

* HS4 (production): Add (head_size % 4 == 0) clause to the MEA dispatch
  predicate at core/providers/cuda/llm/attention.cc as forward-looking
  defense-in-depth. The clause is REDUNDANT TODAY: has_memory_efficient_
  attention already enforces (qk_head_size & 7) == 0 (i.e. % 8) at
  contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h:71-72,
  which strictly implies % 4. We are not closing a current bug. The clause
  is kept as the strictest dtype-agnostic alignment floor that CUTLASS
  FMHA's BiasLoader actually requires (BiasLoader hardcodes a 128-bit /
  sizeof_bits-element alignment on Q/bias loads — 4 elements for fp32).
  Once microsoft#28365 lands and BiasLoader switches to
  kAlignmentA / DispatchIsAligned, MEA's own % 8 invariant will be loosened
  and this clause becomes load-bearing, preventing a correctness regression.
  The new comment block at the predicate site cites microsoft#28365 so the next
  maintainer can identify the right moment to delete it.

* REG (test): TestONNXAttentionGQAAsymmetricHeadSize pins the asymmetric
  v_head_size != head_size GQA path on fp16 and bf16 to guard against
  regression of the microsoft#28358 fix that removed the LaunchUngroup head_size ==
  v_head_size ENFORCE.

* HS4 (test): TestONNXAttentionGQAHeadSizeMod4 sweeps head_size in
  {6, 10, 12, 16, 24}. Today head_sizes 6/10/12 are filtered upstream by
  MEA's % 8 gate and take the unfused fall-through path; this test pins
  that fall-through stays numerically correct. 16/24 satisfy both % 8 and
  % 4 and exercise the MEA happy path. Once microsoft#28365 relaxes MEA's % 8
  invariant, head_sizes 6/10 will start exercising the HS4 host-side gate
  directly.

* 1c (test): TestONNXAttentionGQAOutputQK pins the GQA + qk_matmul_output_mode
  combination (kQK raw scaled QK output) which previously had no test
  coverage. Threads an optional output_qk parameter through common.py's
  create_attention_graph_prompt and attention_prompt_func.

* 1e (test): TestONNXAttentionGQASoftcapFloat32 pins the fp32 + softcap +
  GQA combination (symmetric and asymmetric V head). fp32 GQA always falls
  through to the unfused path on CUDA; existing softcap tests are fp16/bf16,
  so the fp32 unfused softcap branch had no parity coverage.

All 10 new tests pass on H100 (sm_90a). Full test_gqa.py: 91 passed,
3 pre-existing flakes ('Output mismatch between two runs' determinism
checks in unrelated decode-flash classes — not regressions from this
change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment


Pull request overview

This PR hardens CUDA ONNX Attention runner selection and expands Python-side GQA coverage around spec/dispatch edge cases. It fits into the CUDA attention stack by tightening MEA eligibility in attention.cc and extending the transformer parity tests/helpers used to validate ONNX Attention behavior.

Changes:

  • Adds a defensive head_size % 4 == 0 guard to CUDA MEA eligibility in onnxruntime/core/providers/cuda/llm/attention.cc.
  • Extends the shared Python ONNX-Attention test helper to optionally request/bind output_qk.
  • Adds new GQA regression/spec-coverage tests for asymmetric head sizes, head-size routing, output_qk, and fp32 softcap; updates the internal CUDA attention skill doc.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/providers/cuda/llm/attention.cc | Tightens MEA dispatch eligibility with an added host-side head-size alignment check and explanatory comments. |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds optional output_qk plumbing to the prompt-phase graph builder and execution helper. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Adds new CUDA GQA regression/spec-coverage tests for asymmetric V heads, head-size routing, raw output_qk, and fp32 softcap. |
| .agents/skills/cuda-attention-kernel-patterns/SKILL.md | Updates internal agent guidance for CUDA Attention MEA eligibility. |


Four targeted follow-ups to the bot review on
microsoft#28371. No production
behaviour change beyond the comment text and a Python helper guard.

* core/providers/cuda/llm/attention.cc — soften the HS4 deletion
  criterion. The original 'delete when MEA no longer requires the %8
  invariant' is necessary but not sufficient: removing the clause
  would also need every other host-side gate to keep head_size < 4
  out of LaunchUngroup, which still ORT_ENFORCEs head_size % 4 == 0
  internally (see ~line 723-724). Cite microsoft#28365
  and the LaunchUngroup ENFORCE site explicitly.

* test/python/transformers/test_onnx_attention/common.py — fix
  output_qk negative-mode bug. Helper used 'output_qk is not None'
  to gate the optional 4th output, but a caller mirroring the C++
  enum convention (kNone = -1 in attention_parameters.h) would pass
  -1 and silently get the 4th output bound + the unfused CUDA kernel
  populating it as raw-QK. Tighten the gate to '>= 0' across all
  three sites (graph node, output binding, return tuple) and update
  the prominent NOTE block + docstrings to spell out the convention.

* test/python/transformers/test_onnx_attention/test_gqa.py — add
  test_gqa_softcap_fp32_with_mask_ordering_{symmetric,asymmetric_v_head}
  to TestONNXAttentionGQASoftcapFloat32. The existing fp32 softcap
  cases passed attn_mask=None, so they could not detect a wrong
  softcap-vs-mask order on the unfused fp32 path (without a mask the
  two orders are arithmetically identical). The new tests use the
  same poison-V pattern as the fp16/bf16 P1 ordering guards (small
  softcap, V=1000 in masked slot, attn_mask=-inf there) so wrong
  ordering produces wild magnitudes / NaN and right ordering yields
  bounded finite values. Compare against attention_ref().

* test/python/transformers/test_onnx_attention/test_gqa.py — correct
  the TestONNXAttentionGQAOutputQK docstring. The fp16/bf16 restriction
  applies only to the MEA LaunchUngroup helper, not the entire GQA-on-
  CUDA surface; the unfused fall-through DOES support fp32, exercised
  by TestONNXAttentionGQASoftcapFloat32 in the same file.

All four fixes verified locally:
  - 96/96 in test_onnx_attention/test_gqa.py pass (PR-2 build with
    HS4 dispatch + GQA-fp32 MEA exclusion).
  - 4/4 TestONNXAttentionGQASoftcapFloat32 pass (2 existing + 2 new
    masked ordering tests).
  - The P1-style ordering guard
    test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16
    still passes; the common.py output_qk change does not affect
    paths that don't request output_qk.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms (Contributor, Author) commented:

Thanks @copilot for the careful pass — the negative-int output_qk issue (#2) and the softcap-without-mask test gap (#3) were both real and worth fixing. Pushed cb8bd5902a (additive commit, no force-push, so the review history stays intact):

  1. common.py — output_qk negative-mode hardening. enable_output_qk / output_qk_enabled now treat any negative int (matching the C++ kNone = -1 sentinel) as disabled at all three gating sites — graph outputs, IO binding, and the return-tuple shape — plus the qk_matmul_output_mode attribute itself is clamped to 0 when negative.
  2. test_gqa.py — fp32 softcap + mask ordering coverage. Added _with_mask_ordering_symmetric and _with_mask_ordering_asymmetric_v_head using the same poison-V pattern as the existing fp16/bf16 ordering guards: NaN check + max_abs < 1.0 leakage assertion + parity vs attention_ref() on the unfused fp32 GQA path that PR-1's softcap+mask reorder landed in. (See the sketch after this list for the intuition behind the poison-V pattern.)
  3. attention.cc — HS4 deletion comment. Tightened to require BOTH preconditions before removal: (a) MEA dropping its (qk_head_size & 7) == 0 invariant AND (b) every other host-side gate keeping head_size < 4 out of LaunchUngroup.
  4. test_gqa.py — TestONNXAttentionGQAOutputQK docstring. Scoped the fp16/bf16 requirement to the GQA+MEA path; the unfused fall-through does support fp32 (already exercised by TestONNXAttentionGQASoftcapFloat32).
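
For intuition on the poison-V pattern referenced in item 2, a toy numpy sketch of why the softcap-vs-mask order is only observable once a mask is present (single head, illustrative constants; the real tests use attention_ref and the helpers in common.py):

```python
import numpy as np

rng = np.random.default_rng(0)
seq, v_head = 4, 8
scores = rng.standard_normal((seq, seq)).astype(np.float32)  # pre-softmax QK scores
v = np.ones((seq, v_head), dtype=np.float32)
v[-1] = 1000.0                                   # poison V at the masked key position
mask = np.zeros((seq, seq), dtype=np.float32)
mask[:, -1] = -np.inf                            # additive mask hides the poisoned key
softcap = 0.5

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

# Spec order (scale -> softcap -> +mask -> softmax): -inf survives the softcap,
# the poisoned key gets exactly zero weight, and the output stays bounded (~1.0).
right = softmax(softcap * np.tanh(scores / softcap) + mask) @ v

# Wrong order (mask added before softcap): tanh squashes -inf to -1, so the
# poisoned key keeps non-trivial weight and the 1000.0 leaks into the output.
wrong = softmax(softcap * np.tanh((scores + mask) / softcap)) @ v

assert np.all(np.abs(right) < 2.0)
assert np.max(np.abs(wrong)) > 10.0
```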

No production CUDA logic changes — diff is 3 files (attention.cc comment-only, common.py guards, test_gqa.py tests + docstring).

@titaiwangms titaiwangms requested a review from Copilot May 5, 2026 22:56
titaiwangms (Contributor, Author) left a comment:


Reviewed by code, critical, and readability reviewers — all passed.

Verdict: Approve.

Guard is safe (no-op today, correctly forward-looking for #28365). API change is backward-compatible. Tests are solid — 10 cases covering asymmetric V heads, head_size sweep, output_qk GQA, and softcap fp32 with poison-V ordering validation.

One minor suggestion: The 24-line comment on the head_size % 4 == 0 clause is disproportionate for a currently-redundant guard. Consider trimming to ~8 lines and moving the detailed deletion criteria to the tracking issue (#28365).

titaiwangms (Contributor, Author) left a comment:


Consolidated review

Reviewed by Opus 4.7 (lead) + GPT-5.3-Codex (code) + GPT-5.5 (adversarial) + Claude Sonnet 4.6 (readability). Strong consensus on three substantive issues.


🔴 Major — head_size % 4 is the wrong floor for fp16/bf16; comment overstates what #28365 does

File: onnxruntime/core/providers/cuda/llm/attention.cc:1383-1413, .agents/skills/cuda-attention-kernel-patterns/SKILL.md:10

The clause's own comment block states CUTLASS BiasLoader uses 128 / sizeof_bits<T> elements:

  • fp32 → 4 elements (so %4 is correct)
  • fp16/bf16 → 8 elements (so %4 is insufficient)

This means once #28365 lands and the upstream %8 gate is gone, fp16/bf16 with head_size = 4 would pass this guard but still violate the actual 8-element vectorized-load requirement. The clause therefore cannot serve as the "dtype-agnostic alignment floor" the comment claims it is.

Additionally — verifying the comment's premise — the public diff/metadata for #28365 changes BiasLoader to use kAlignmentA and adjusts bias-stride dispatch, but does not appear to remove the (qk_head_size & 7) == 0 invariant in has_memory_efficient_attention(). So the comment's claim that "MEA's %8 invariant goes away" once #28365 lands is not supported by that PR as currently written.

The combined effect: this clause is being inserted as the future source of truth for alignment, but it (a) doesn't match the actual dtype-dependent kernel requirement and (b) anchors itself to a future PR that doesn't actually own the invariant change. Future maintainers reading the comment may remove the %8 gate based on this clause and silently admit unsafe fp16/bf16 shapes.

Suggested fix (one of):

  • Make the guard dtype-aware: head_size % (16 / sizeof(T)) == 0, or wrap in a GetMinMEAHeadSizeAlignment<T>() helper that also lives next to the kernel traits.
  • Drop the clause; let the %8 gate stay until the same PR that relaxes CUTLASS alignment also relaxes MEA eligibility (single source of truth).
  • If keeping as-is, rewrite the comment to make clear this is a conservative under-approximation and not the full CUTLASS requirement, plus weaken the #28365 claim.

🟡 Major — TestONNXAttentionGQAHeadSizeMod4 doesn't prove the routing invariant it claims

File: onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py:2025-2090

It's a black-box parity test that admits in its own docstring that hs ∈ {6,10,12} are filtered by the upstream %8 gate today, so it currently exercises only the unfused fall-through. Three structural problems:

  1. No routing assertion — once %8 is relaxed, some other gate could still filter the same cases and this test would silently keep passing without ever exercising the new HS4 clause.
  2. attn_mask=None — does not hit the CUTLASS BiasLoader path that the C++ comment cites as the entire reason for this guard.
  3. No boundary cases — fp16/bf16 head_size = 4 is exactly the case that would prove or disprove the dtype-safety concern in the Major above. Missing.

Suggested fix: Add a routing observable (kernel-marker / profiling hook / log assertion) plus boundary cases (fp16 hs=4, fp16 hs=6, fp16 hs=8) and at least one variant with an additive mask so the BiasLoader path is actually exercised.


🟡 Minor — output_qk: int | None API now has a 3-way sentinel contract

File: onnxruntime/test/python/transformers/test_onnx_attention/common.py:100-177

Three meanings encoded in one parameter:

  • None → disabled (default)
  • any negative int → also disabled (defensive C++ kNone = -1 mirror)
  • >= 0 → enabled, value picks the mode

The fact that the docstring needs an 8-line NOTE block to disambiguate is the API smell. No caller in this diff actually passes -1, so the negative branch is speculative defensive code. Also, output_qk=0 previously meant "disabled" in this helper (if output_qk > 0), and now means "enabled raw-QK" — a silent semantics flip for any out-of-tree caller.

Suggested fix (any one):

  • Rename to output_qk_mode: int | None = None, document only the 2-way contract (None / non-negative int), and assert non-None values are in {0,1,2,3}. Drop the negative-int branch.
  • Or: reject negative values explicitly with a clear error message, rather than silently disabling them.
  • Add a focused test pinning the new contract: None → 3-tuple, 0 → 4-tuple raw-QK, invalid mode → assert/raise.
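
As a rough illustration of the last suggestion, a focused contract-pinning test could look like the sketch below. The run_prompt wrapper is a hypothetical stand-in for the real helper in common.py; only the output arity and the validation behaviour are the point:

```python
import unittest

def run_prompt(output_qk=None):
    """Hypothetical stand-in for common.py's prompt helper: returns output names only."""
    if output_qk is None:
        return ("Y", "present_key", "present_value")
    assert output_qk in {0, 1, 2, 3}, f"invalid output_qk mode: {output_qk!r}"
    return ("Y", "present_key", "present_value", "qk_matmul_output")

class TestOutputQKContract(unittest.TestCase):
    def test_none_disables_fourth_output(self):
        self.assertEqual(len(run_prompt(output_qk=None)), 3)

    def test_mode_zero_enables_raw_qk(self):
        self.assertEqual(len(run_prompt(output_qk=0)), 4)

    def test_invalid_modes_raise(self):
        for bad in (-1, 4, 5):
            with self.assertRaises(AssertionError):
                run_prompt(output_qk=bad)

if __name__ == "__main__":
    unittest.main()
```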

Minor — C++ comment buries the lede; cross-reference will go stale

File: onnxruntime/core/providers/cuda/llm/attention.cc:1383-1413

The 20-line block leads with five lines of BiasLoader mechanics; the operationally surprising fact ("this clause has no effect today") is in paragraph two. New readers will assume it's load-bearing. Also, the cross-reference to the ORT_ENFORCE near lines 723-724 of this file will silently go stale after any surrounding edit.

Suggested 5-line restructure:

// Alignment floor for Cutlass FMHA's BiasLoader (16-byte loads = 4 fp32 / 8 fp16/bf16
// elements; head_size is the inner stride). NOTE: %4 covers fp32; fp16/bf16 actually
// need %8. Currently redundant — has_memory_efficient_attention() enforces %8.
// TODO(#28365): becomes load-bearing once that gate is relaxed AND a dtype-aware
// alignment helper replaces this hard-coded %4. See LaunchUngroup ENFORCE for fallback.
(parameters.head_size % 4 == 0);

🟢 Minor — Inconsistent naming + missing test coverage

  • enable_output_qk (line 92) vs output_qk_enabled (line 158) — same Boolean, two names. Pick output_qk_enabled (predicate form).
  • Asymmetric-GQA tests check out but ignore present_key / present_value; no decode-phase asymmetric coverage. Worth at least asserting present-KV shapes.
  • SKILL.md final sentence is one over-stuffed sentence with two parentheticals + a semicolon clause. Split into two.

✅ What's right

  • Falling through to unfused instead of allowing a later kernel launch failure is the correct architectural direction.
  • _run_softcap_fp32_with_mask poison-V pattern is exemplary — the "correct order vs wrong order" walkthrough plus the max_abs < 1.0 leakage assertion is exactly the right shape for a non-obvious test.
  • TestONNXAttentionGQAAsymmetricHeadSize docstring is a model: names the regression, cites the fix PR, explains the prior failure mode, states what's pinned.
  • Defaults preserve backward compat (3-tuple return).
  • Host-side change only — no thread-safety / perf concerns.

Recommendation

Request changes, primarily on the dtype-safety of the %4 floor and the comment's claim about #28365. The HS4 test should also gain a routing observable and the fp16/bf16 hs=4 boundary case before this clause can be meaningfully validated. Everything else is polish.


Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.



titaiwangms and others added 2 commits May 5, 2026 23:37
… tighten output_qk validation

Three substantive items + one docstring fix from round-2 reviewer
feedback (bot + internal multi-reviewer consolidation).

* core/providers/cuda/llm/attention.cc — drop the host-side
  `head_size % 4 == 0` HS4 clause from the MEA-eligibility predicate
  and remove its multi-paragraph comment block. The clause is fully
  redundant today (`has_memory_efficient_attention` already requires
  `(qk_head_size & 7) == 0`, which strictly implies %4) and the comment
  it carried made dtype-aware alignment claims that are wrong for fp16
  / bf16 (BiasLoader needs an 8-element stride, not 4, for those
  dtypes). The dtype-aware alignment floor properly belongs in the
  BiasLoader fix (microsoft#28365), not as a vestigial
  redundant clause here. Predicate is now exactly the upstream/main
  shape for HS4 purposes.

* test/python/transformers/test_onnx_attention/test_gqa.py — delete
  TestONNXAttentionGQAHeadSizeMod4. With the HS4 clause gone there is
  no host-side gate left to validate; the parameterized sweep was
  exercising routing equivalence vs the unfused fall-through, which is
  already covered by the broader MEA / unfused tests.

* test/python/transformers/test_onnx_attention/common.py — tighten the
  output_qk parameter validation to `output_qk in {0, 1, 2, 3}` or
  `None`. The previous `is not None and >= 0` guard caught the C++
  `kNone = -1` sentinel but still silently accepted invalid modes
  4 / 5, which would build an ONNX graph with an out-of-range
  `qk_matmul_output_mode` attribute and let the unfused CUDA kernel
  populate the 4th output as raw-QK regardless. Validation now raises
  immediately with a clear message at the helper boundary; the
  binding-allocation site downstream is simplified to `is not None`
  since validation has already happened. NOTE block + both helper
  docstrings updated to spell out the contract: `None` = disabled;
  `{0,1,2,3}` = the corresponding QKMatMulOutputMode; anything else
  raises.

* test/python/transformers/test_onnx_attention/test_gqa.py — fix the
  TestONNXAttentionGQAAsymmetricHeadSize docstring. The pre-microsoft#28358
  `head_size == v_head_size` ENFORCE in LaunchUngroup is an MEA-path
  enforcement (LaunchUngroup is the GQA head-expansion helper used by
  MEA before its FMHA kernel), not an unfused-path one. Docstring now
  correctly attributes it.

Verified locally on the PR-2 build (build_pr2/, sm_90a single-arch):
  - All targeted PR-2 + ordering-guard tests pass (8/8): the existing
    OutputQK / SoftcapFloat32 / AsymmetricHeadSize / LargeHeadUnfused
    poison ordering guard, plus the 2 masked fp32 ordering tests added
    in the previous fix-up.
  - test_onnx_attention/test_gqa.py: 89/91 pass on a quiet GPU. The 2
    transient failures (FloatMaskDecode, MEAGQASoftcap softcap+mask
    decode) both pass cleanly when re-run in isolation; they are
    pre-existing run-to-run flakes (rtol=0/atol=0 strict-equality
    asserts) under shared-GPU pressure, not caused by this commit.
  - HS4 sweep class is gone (test-case count dropped from 96 to 91; the
    delta is exactly the 5 parameterized HS4 sweep cases, as
    expected).
  - Manual negative test of the new validation:
      output_qk=None         -> 3 outputs (disabled, OK)
      output_qk=2            -> 4 outputs (kQKSoftCap, OK)
      output_qk=-1           -> AssertionError (OK)
      output_qk=4            -> AssertionError (OK)
      output_qk=5            -> AssertionError (OK)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
PR microsoft#28371 commit e8a8db7 dropped the host-side head_size % 4 == 0
HS4 clause from the MEA-eligibility predicate in
core/providers/cuda/llm/attention.cc. This commit removes the now-
orphaned reference in the cuda-attention-kernel-patterns skill (§1
MEA eligibility bullet), which previously documented the HS4 floor
as a forward-looking alignment guard. The LaunchUngroup, decode,
GQA-fp32, and bias-stride clauses are unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms (Contributor, Author) commented:

Thanks @copilot for the second pass and to the multi-reviewer internal review for catching the same three issues — the convergence between the two streams was a strong signal these were real, not noise.

Round-2 fix-up landed as 2 additive commits (no force-push, so the review history stays intact): e8a8db723a + 99a48bc689.

  1. Dropped the HS4 host-side gate entirely. Both reviews correctly pointed out that head_size % 4 == 0 is the wrong floor for fp16/bf16 (BiasLoader's 16-byte load needs % (16 / sizeof(T)) — 4 elements for fp32, 8 for fp16/bf16), and the comment overstated what #28365 (Fix CUDA EP: opset 24 kernel registrations + CUTLASS alignment + MEA dispatch) does. The dtype-aware alignment floor is architecturally the responsibility of #28365, which owns the BiasLoader → kAlignmentA / DispatchIsAligned relaxation, not a host-side gate here. Removed: the clause in cuda/llm/attention.cc, the now-orphaned TestONNXAttentionGQAHeadSizeMod4 sweep, and the SKILL.md §1 citation. The MEA-eligibility predicate is now byte-equivalent to upstream on this front. Posted a heads-up on #28365 with the precise math plus a MinMEAHeadSizeAlignment<T>() suggestion for whoever loosens memory_efficient_attention.h:71: #28365 (review).

  2. Tightened output_qk validation in test_onnx_attention/common.py from is not None and >= 0 to assert output_qk in {0, 1, 2, 3} (or None), with a !r-formatted error so type confusion ("1" vs 1) surfaces at the diagnostic boundary. All three binding sites (graph node outputs, IO binding, return-tuple shape) gate consistently off the post-validation enable_output_qk flag. -1 / 4 / 5 now raise at graph build instead of silently disabling or building invalid models.

  3. Fixed the TestONNXAttentionGQAAsymmetricHeadSize docstring to correctly attribute LaunchUngroup to the MEA path (not unfused) — LaunchUngroup is only used by RunMemoryEfficientAttention for K/V head expansion before the FMHA kernel.

Verified locally: 8/8 PR-2-relevant tests + the existing fp16/bf16 ordering guards pass. The two pre-existing flakes in the broader test_gqa.py sweep are unrelated (same flakes reproduce on upstream/main).

No production CUDA logic changes in this round-2; the diff is HS4 removal + output_qk validation + docstring/SKILL.md hygiene.

Collapses the multi-line parenthesized assertion message at
common.py:163-165 to a single line, per RUFF/Black formatter.
Assertion semantics unchanged (verified manually: None and
{0,1,2,3} accepted; -1, 4, 5 raise the same AssertionError
with the same message).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title Tighten CUDA Attention MEA eligibility (head_size % 4) and add ONNX-Attention spec-coverage tests Add CUDA Attention spec-coverage tests (GQA asymmetric head-size, output_qk, fp32 softcap+mask ordering) May 6, 2026
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 6, 2026
…rosoft#28371

Migrate the output_qk plumbing in common.py (now `int | None`-typed with strict {0,1,2,3} validation) and three CUDA GQA tests from PR microsoft#28371 to this PR, where they semantically belong:

- The output_qk parameter numbering follows the post-onnx#7913 enum swap (kQK/kPostSoftCap/kPostMaskBias/kPostSoftMax) introduced in this PR.
- The masked fp32 softcap ordering tests pin the post-onnx#7867 'scale -> softcap -> +mask -> softmax' spec that this PR's CPU implementation enforces.

Migrated to test_gqa.py:
- TestONNXAttentionGQAOutputQK (1 test: GQA + raw QK output, fp16, unfused path)
- TestONNXAttentionGQASoftcapFloat32MaskOrdering (helper + 2 tests: symmetric and asymmetric-V-head poison-V tests on the unfused fp32 GQA path)

The unmasked fp32 GQA softcap baseline tests (TestONNXAttentionGQASoftcapFloat32) remain on microsoft#28371 — they are pure CUDA-side softcap coverage that does not depend on the spec ordering or the enum swap.

This migration also resolves the textual conflict between microsoft#28371 and microsoft#28379 in common.py and test_gqa.py, since the output_qk API can now only land once (here).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 6, 2026
…SKILL.md cleanup)

User requested folding PR microsoft#28371 entirely into this PR to eliminate the merge-ordering hazard between the two PRs and present a coherent fix+tests+docs reviewer story.

Following the earlier d613966 migration of the output_qk plumbing and masked fp32 softcap ordering tests, this commit migrates the remaining microsoft#28371 content:

test_gqa.py:
- TestONNXAttentionGQAAsymmetricHeadSize (REG, 2 tests) — guards the silent-broken-output regression on GQA + asymmetric Q/V head sizes that was fixed by PR microsoft#28358 (microsoft#28357). Pins the post-fix unfused-path behaviour on fp16 + bf16.
- TestONNXAttentionGQASoftcapFloat32 (1e baseline, 2 tests) — pins fp32 + softcap + GQA on the unfused path (MEA excludes is_gqa && fp32). Sibling to TestONNXAttentionGQASoftcapFloat32MaskOrdering already in this PR.

SKILL.md (cuda-attention-kernel-patterns):
- MEA eligibility paragraph: clarify that head_size%8 is enforced by has_memory_efficient_attention, and that head_size == v_head_size is required for GQA (LaunchUngroup) in addition to decode (LaunchConcatNewToPastKV). Reflects the post-microsoft#28358 host-side gate cleanup.

After this commit, PR microsoft#28371 is fully superseded; the lead will close it.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms (Contributor, Author) commented:

Closing in favor of #28379, which now carries the consolidated scope.

Reason for consolidation: the two PRs ended up semantically coupled — the output_qk docstring/contract on this PR uses the pre-#7913 enum names that #28379 renames, and both PRs touch common.py, test_gqa.py, and .agents/skills/cuda-attention-kernel-patterns/SKILL.md with textual conflicts. Folding them into one PR eliminates merge-ordering risk and gives reviewers a single coherent fix + tests + docs change.

No coverage lost. All test classes added here now live on #28379:

  • TestONNXAttentionGQAAsymmetricHeadSize (REG)
  • TestONNXAttentionGQAOutputQK (1c)
  • TestONNXAttentionGQASoftcapFloat32 + the masked-ordering variants (1e)
  • common.py output_qk contract tightening
  • SKILL.md §1 HS4 cleanup

The HS4 drop and the dtype-aware-floor deferral to #28365 (review 4232187809) carry forward unchanged. The opset-24 short-attn_mask work (1d) remains a separate follow-up after #28379 lands.

Thanks @copilot-pull-request-reviewer for the two review rounds — every finding was carried forward and addressed in #28379. Branch left at 449fa8cfa5 for reference.


@titaiwangms titaiwangms closed this May 6, 2026
@titaiwangms titaiwangms deleted the fix-attention-spec-gaps-and-gqa-tests branch May 6, 2026 19:30