
Fix CPU Attention softcap/attn_mask ordering (onnx#7867, #7913) + consolidate CUDA spec coverage tests from #28371 #28379

Open
titaiwangms wants to merge 8 commits into microsoft:main from titaiwangms:fix-cpu-attention-softcap-mask-ordering

Conversation

@titaiwangms
Contributor

@titaiwangms titaiwangms commented May 5, 2026

Fix CPU Attention softcap/attn_mask ordering to match ONNX v24 spec (#7867, #7913) + consolidate CUDA spec coverage tests (#28371)

Branch: fix-cpu-attention-softcap-mask-ordering (off microsoft/onnxruntime main)
Tip: bfe33d3d08

Supersedes #28370 and #28371.

We recommend closing both #28370 and #28371 in favor of this one.


Scope

  1. CPU spec-correct softcap+mask ordering fix (per onnx/onnx#7867, "Fix Attention op softcap ordering: apply before mask/bias") — the headline bug. The CPU attention.cc was applying attn_mask before softcap, which would clip a -inf masked logit to a finite value and leak the masked V slot through softmax. It now applies scale → softcap → +attn_mask → softmax in spec order (a minimal sketch of this pipeline appears after this list).

  2. qk_matmul_output_mode enum value swap (per onnx/onnx#7913, "Swap qk_matmul_output_mode 1 & 2 to match computation order"). Breaking change to mode 1/2 semantics + matching C++ enumerator rename. The CPU op fully implements the new ordering; the CUDA op (Flash/MEA/Unfused dispatch) is confirmed unchanged in behaviour.

  3. Tianlei blocker tests (CPU): a mode-1+softcap differentiating test and a softcap+nonpad_kv_seqlen no-leakage test. Both failed before this PR's fix.

  4. Comprehensive CUDA Attention spec coverage, consolidated from #28371 ("Add CUDA Attention spec-coverage tests: GQA asymmetric head-size, output_qk, fp32 softcap+mask ordering"):

  5. SKILL.md updates (.agents/skills/cuda-attention-kernel-patterns/SKILL.md):

  6. ONNX node-test skip lists (transient, until ONNX v1.22+ is bundled):

    • C++ runner: onnxruntime/test/onnx/TestCase.cc::GetBrokenTests() — 20 attention entries.
    • Python runner: onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc::current_failing_tests — 7 CPU attention entries.

    Both block comments cite onnx/onnx#7867 ("Fix Attention op softcap ordering: apply before mask/bias") and onnx/onnx#7913 ("Swap qk_matmul_output_mode 1 & 2 to match computation order"), plus the v1.22+ unblock criterion: the fixtures bundled in v1.21.0 were generated before the spec PRs landed and therefore encode the now-incorrect ordering. Once cmake/external/onnx is bumped to ≥1.22, both skip blocks should be removed in one commit.
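
For orientation, here is a minimal NumPy sketch of the spec-correct pipeline and the four qk_matmul_output_mode snapshot points under the post-#7913 numbering. It is illustrative only; the function and variable names are not taken from attention.cc, and shape/broadcasting details are simplified:

```python
import numpy as np

def reference_attention_scores(Q, K, attn_mask, scale, softcap, qk_output_mode):
    """Illustrative sketch (names are ours, not ORT's) of the corrected stage order:
    0 = kQK, 1 = kPostSoftCap, 2 = kPostMaskBias, 3 = kPostSoftMax."""
    scores = scale * (Q @ K.swapaxes(-1, -2))            # stage 0: kQK
    snapshots = {0: scores.copy()}
    if softcap > 0.0:                                    # softcap BEFORE mask/bias (onnx#7867)
        scores = softcap * np.tanh(scores / softcap)
    snapshots[1] = scores.copy()                         # stage 1: kPostSoftCap
    if attn_mask is not None:
        scores = scores + attn_mask                      # stage 2: kPostMaskBias
    snapshots[2] = scores.copy()
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)           # stage 3: kPostSoftMax
    snapshots[3] = probs
    return probs, snapshots[qk_output_mode]
```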


Cross-references


Verification

  • AttentionTest.* — 60/60 pass (gtest, CPU + CUDA).
  • TestONNXAttentionCPUSoftcapMaskOrdering — 4/4 pass (CPU spec fix + Tianlei blockers).
  • TestONNXAttentionGQAAsymmetricHeadSize — 2/2 pass (fp16 + bf16, H100).
  • TestONNXAttentionGQAOutputQK — 1/1 pass (fp16, H100).
  • TestONNXAttentionGQASoftcapFloat32 + …MaskOrdering — 4/4 pass (fp32, H100).
  • ONNX node tests: previously-failing 20 (C++) + 7 (Python) attention fixtures now SKIPPED with citation, not failed.
  • lintrunner clean on all touched files.

titaiwangms and others added 3 commits May 5, 2026 20:57
… (canary)

This commit adds 4 new tests (2 Python, 2 C++) that verify ONNX Attention
opset 23/24 spec ordering -- scale*QK -> softcap -> add bias/mask -> softmax,
per onnx/onnx#7867 (which superseded the now-closed onnx/onnx#7865 issue) and
onnx/onnx#7913 (which swapped qk_matmul_output_mode values 1 and 2 to align
with the corrected pipeline).

The tests are written using the small-softcap + poison-V technique already
established by the existing CUDA-only guards at:
  - test/python/transformers/test_onnx_attention/test_gqa.py:1501
    (test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16)
  - test/python/transformers/test_onnx_attention/test_gqa.py:1761
    (test_mea_gqa_softcap_mask_ordering_no_leakage_prompt_fp16)

If softcap is applied AFTER mask-add, then tanh(-inf/softcap)*softcap =
-softcap (a finite value), which leaks probability through softmax to the
masked position. With V=1000 placed at the masked position, the wrong order
produces output ~155 (C++) / ~120 (Python) instead of the spec-correct ~0.2.
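
As a standalone illustration of that arithmetic (not part of the commit), a few lines of NumPy reproduce the leak; the numbers are made up and only the ordering matters:

```python
import numpy as np

softcap = 0.1                                  # small softcap, as in the poison-V technique
qk = np.array([1.0, 2.0, 3.0])                 # scale*Q*K^T scores for one query row
mask = np.array([0.0, 0.0, -np.inf])           # last key position is masked out
V = np.array([0.1, 0.3, 1000.0])               # poison value sits at the masked position

def softmax(x):
    e = np.exp(x - np.max(x[np.isfinite(x)]))
    return e / e.sum()

# Spec-correct order: softcap, then mask. The masked logit stays -inf -> weight 0.
good = softmax(softcap * np.tanh(qk / softcap) + mask) @ V      # ~0.2

# Buggy order: mask, then softcap. tanh(-inf)*softcap = -softcap (finite),
# so the masked position keeps nonzero probability and the poison V leaks through.
bad = softmax(softcap * np.tanh((qk + mask) / softcap)) @ V     # ~290

print(good, bad)
```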

CANARY EVIDENCE (this commit, no production code change):

  C++ AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CPU
    FAILED -- cur_actual = 155.531, cur_expected = 0.200, delta 155.33

  C++ AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CUDA
    PASSED (sentinel: CUDA was already spec-correct)

  Python TestONNXAttentionCPUSoftcapMaskOrdering
    .test_cpu_attention_softcap_additive_mask_poison_prompt_fp32
    FAILED -- max |output| = (above 50 threshold)

  Python TestONNXAttentionCPUSoftcapMaskOrdering
    .test_cpu_attention_softcap_mask_ordering_no_leakage_prompt_fp32
    FAILED -- max |output| = 120.83

  Python existing CUDA guards (sentinel sanity):
    test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16  PASSED
    test_mea_gqa_softcap_mask_ordering_no_leakage_prompt_fp16      PASSED

Also refreshes the SKILL.md citations from onnx/onnx#7865 to onnx/onnx#7867
+ onnx/onnx#7913 in section 1 (MEA eligibility) and section 4 (Softcap
Ordering). Section 4 is rewritten to spell out the full pipeline and to
reference the new CPU-side guard tests.

The CPU production fix that flips these CPU canaries from FAIL to PASS
is intentionally split into the next commit, so CI publicly records the
FAIL -> PASS transition and proves the new tests actually exercise the
ordering bug.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…microsoft#7913)

The CPU ONNX Attention op was applying mask/bias BEFORE softcap, violating
the v23/24 spec post onnx/onnx#7867 + onnx/onnx#7913. With finite softcap
active, -inf mask values were squashed by tanh into bounded -c, causing
probability leakage onto masked positions and arithmetic mixing of poison
V-values into the output (the canary tests in commit d76e45c demonstrate
the leak: max |y| > 100 with poison V at a -inf-masked position).

This is the inverse fix to PR microsoft#28370: CUDA was already correct per spec
(it splits the softcap snapshot before the mask add); CPU is the violator.

Changes:
- ComputeAttentionProbs<T>: refactor the per-head loop. When softcap is
  active, run GEMM with beta=0 (raw scale*Q*K^T), apply softcap inplace,
  then add mask explicitly via new AddInPlace<T> helper. When softcap is
  disabled, preserve the original beta=1 fold path (mask preloaded into C,
  FMA-accumulated) so pre-existing test calibrations remain numerically
  identical. Snapshot offsets for kQK / kPostSoftCap / kPostMaskBias /
  kPostSoftMax now follow the spec-correct stage order.
- AddInPlace<T>: new helper. Uses MlasEltwiseAdd<float> for fp32; FP16
  takes a scalar-fallback path because MlasEltwiseAdd<MLAS_FP16>'s
  dispatch->Add_Fp16 is not populated on all builds.
- attention_parameters.h: rename QKMatMulOutputMode enumerators per the onnx/onnx#7913
  numbering swap. Old: kQK=0, kQKMask=1, kQKSoftCap=2, kQKSoftMax=3.
  New:  kQK=0, kPostSoftCap=1, kPostMaskBias=2, kPostSoftMax=3.
- cuda/llm/attention.cc: update enum-tag references in comments and the
  qk_matmul_output_mode ENFORCE message (no logic change; CUDA was already
  spec-correct).
- attention_op_test.cc Attention4DWithPastAndPresentQkMatmul: regenerated
  expected y[] and qk_matmul[] arrays for the softcap=1.0 sub-call.
  Pre-fix arrays were calibrated against the OLD buggy ordering
  (softmax(softcap(raw+mask))) and necessarily change once the spec is
  honored (softmax(softcap(raw)+mask)). Five other sub-calls in the same
  test (modes -1/0/1/2/3, no softcap) are unchanged because mask add
  trivially commutes with the no-op softcap.

With this change, the canary tests added in d76e45c transition from
FAIL->PASS on CPU EP while CUDA EP guards continue to pass:
- AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CPU
- AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CUDA
- TestONNXAttentionCPUSoftcapMaskOrdering.test_cpu_attention_softcap_*

Verification: 58/58 AttentionTest.* + 19/19 GQATest.* + 217/218 Python
transformers/test_onnx_attention (the one failure is an unrelated
pre-existing tensorscatter validation test, unmodified by this PR).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…lded, doc polish

Applies the consolidated findings from the PR-1-v2 review pass
(lead-39245992/pr1v2-review-{code,critical,readability}.md, pr1v2-qa.md):

- (Critical M1, blocking) Independent oracle verification of the regenerated
  baseline values for `Attention4DWithPastAndPresentQkMatmul`'s softcap=1.0
  sub-call. attention_ref() (test_onnx_attention/common.py) was used as a
  spec-correct ground-truth oracle outside the just-modified CPU code path.
  Result: max |y - oracle| = 5.4e-7, max |qk_matmul - oracle| = 6.0e-7,
  3+ orders of magnitude inside the fp16 tolerance (atol=rtol=1e-3) used
  elsewhere in the suite. Regen confirmed spec-correct (not circular).
  Verification snippet preserved at lead-39245992/pr1v2-oracle-verify.py.
  A back-reference comment was added immediately above the regenerated
  arrays.

- (Critical M2) Breaking-change callout for the `qk_matmul_output_mode`
  enum value swap (onnx/onnx#7913) added to lead-39245992/pr1v2-description.md.
  Includes the old/new value table and notes that no in-tree consumer pins
  the old numbering.

- (Code minor + Readability M2) DRY the `mask_was_folded` predicate.
  The 5-clause GEMM-fold condition was previously duplicated at the if/else
  that selects beta and again at the post-softcap mask-add. Captured once
  into `fold_mask_into_gemm` at first use and reused; deleted the duplicate
  predicate. Single source of truth for "did the mask get baked into C via
  beta=1?". Comment block updated accordingly.

- (Readability M3) Reconciled AddInPlace<MLFloat16> rationale. The source
  comment now states the precise reason: MlasEltwiseAdd<MLAS_FP16> requires
  the per-platform EltwiseDispatch->Add_Fp16 kernel slot to be populated,
  and only the ARM NEON build provides it (see
  onnxruntime/core/mlas/lib/eltwise.cpp:92-103); x86 and other targets
  would throw at runtime. This matches the precise wording used in the
  original commit-2 message body.

Verification at HEAD:
- 58/58 AttentionTest.* PASS (incl. both PoisonV canaries +
  Attention4DWithPastAndPresentQkMatmul).
- All `*GQA*` test-suite tests PASS.
- 3/3 Python canaries (CPU additive mask, CPU mask ordering, CUDA mask
  ordering) PASS.
- Lintrunner clean.

Deferred per task brief:
- M4 (literal `1->2` justification at attention_op_test.cc:1782) - covered
  by the new regen comment block.
- Cross-PR SKILL.md §1 coordination with PR #2 (readability M1) -
  handled at merge time.
- All reviewer nits.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

Aligns ONNX Attention handling with the corrected softcap→mask/bias ordering and updates the auxiliary qk_matmul_output_mode naming/mapping used by the CPU/CUDA Attention codepaths and tests.

Changes:

  • Refactors the CPU Attention score pipeline so softcap is applied before additive mask/bias, with updated snapshot handling for output_qk.
  • Swaps/renames qk_matmul_output_mode enum values 1 and 2 to match the corrected stage ordering.
  • Adds/updates Python and C++ regression coverage plus related CUDA skill/docs comments.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.

Summary per file:

  • onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py: Adds CPU-side Python canaries for softcap/mask leakage behavior.
  • onnxruntime/test/providers/cpu/llm/attention_op_test.cc: Updates qk_matmul baselines and adds C++ poison-V regression tests.
  • onnxruntime/core/providers/cuda/llm/attention.cc: Renames accepted qk_matmul_output_mode enum references/messages on CUDA.
  • onnxruntime/core/providers/cpu/llm/attention.cc: Implements the CPU production fix for score ordering and output snapshot staging.
  • onnxruntime/core/providers/cpu/llm/attention_parameters.h: Renames/swaps QKMatMulOutputMode enum values to the corrected meanings.
  • .agents/skills/cuda-attention-kernel-patterns/SKILL.md: Updates CUDA Attention skill guidance to reflect the corrected ordering/history.


Contributor Author

@titaiwangms titaiwangms left a comment


Review — multi-model pass (Opus 4.7 lead, with GPT-5.3-Codex / GPT-5.5 / Sonnet 4.6 reviewers)

Edit (post-discussion with author). Finding #1 below was originally framed as “add opset-version gating to preserve as-shipped opset-23 semantics.” That framing was wrong: opset 23 shipped with the leakage bug in both spec text and reference impl, and onnx/onnx#7867 + #7913 are post-release errata. Preserving the as-shipped opset-23 behavior would mean knowingly keeping the leak. The author’s position — that ORT should honor the corrected ordering/numbering for both v23 and v24 ahead of the ONNX v1.22 spec text — is the right call. Finding #1 is therefore reframed as a documentation ask, not a behavior change. Finding #2 (no differentiating test) becomes more important under this stance, since the test suite is the de-facto contract pinning ORT’s chosen interpretation.

The CPU softcap-before-mask fix itself is the right call: poison-V tests are an excellent oracle, the fold_mask_into_gemm refactor is clean, and the deep-copy-before-mutate snapshot discipline avoids the obvious alias trap. Mask handling for -inf is now correct.

That said, the cross-family review surfaced two non-trivial concerns and a handful of polish items.

🔴 Major

1. Document that ORT intentionally leads the spec on qk_matmul_output_mode semantics for opset 23.

The CPU kernel registers separate v23 and v24 kernels (cpu/llm/attention.cc:25-49), and CUDA does the same (cuda/llm/attention.cc:40-69), but both versions instantiate the same Attention<T> whose constructor reads the integer with no info.node().SinceVersion() check. Per onnx/onnx#7913 the integers 1 and 2 were swapped between the as-shipped opset 23 and the corrected pipeline, so an opset-23 model carrying value 1 will resolve to kPostSoftCap instead of the as-shipped “post-mask/bias” meaning.

This is intentional and correct — opset 23 shipped with both the wrong ordering and the wrong numbering, and #7867/#7913 are post-release errata that ORT is implementing ahead of ONNX v1.22. But the comment in attention_parameters.h:14 currently just says “v23/24 use the new numbering,” which reads as a passive statement of fact and obscures that ORT is deliberately diverging from the as-shipped opset-23 spec text.

Suggested fix: expand the comment to make the intent explicit, e.g. “ORT honors the corrected #7867/#7913 ordering/numbering for both opset 23 and opset 24, ahead of the ONNX v1.22 spec text errata. Producers that emitted opset-23 nodes against the as-shipped numbering will see the swapped auxiliary output post-merge; the as-shipped behavior is the leakage bug being fixed.” That makes the policy choice deliberate rather than accidental, and gives downstream tooling a clear pointer.

2. The enum swap is not actually exercised by a differentiating test.

attention_op_test.cc:1544-1571 updates the QK-output test, but the mode-1 case runs with softcap = NaN (i.e. softcap disabled), where kPostSoftCap and kQK are observationally identical. The mode-2 case is also softcap-off, so it only verifies the no-softcap commutative case. The riskiest behavior change in this PR — that mode 1 returns softcap(scale*QK) and mode 2 returns softcap(scale*QK) + mask/bias — is not pinned by any test.

This is more important than it would be under a strict version-gated implementation: because ORT is intentionally leading the spec, the test suite is the contract that documents the chosen interpretation. Without a differentiating test, a future refactor could quietly invert the snapshots and CI would not notice.

Suggested fix: one C++ CPU test with softcap > 0, non-zero additive mask, output-3 wired, and explicit expected values for modes 1 and 2. Bonus: a third case with nonpad_kv_seqlen enabled to lock in that kPostMaskBias includes nonpad masking (currently a defensible-but-undocumented choice — see polish item below).
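
To make the observational-equivalence point concrete, a tiny NumPy sketch (illustrative values only, not taken from the test) of the three non-softmax snapshots:

```python
import numpy as np

qk = np.array([0.5, 1.5, -2.0])      # illustrative scale*Q*K^T row
mask = np.array([0.0, -0.4, 0.0])    # non-zero additive mask

def snapshots(softcap):
    post_softcap = softcap * np.tanh(qk / softcap) if softcap > 0 else qk   # mode 1 (kPostSoftCap)
    post_mask_bias = post_softcap + mask                                    # mode 2 (kPostMaskBias)
    return qk, post_softcap, post_mask_bias                                 # modes 0, 1, 2

print(snapshots(softcap=0.0))   # softcap disabled: mode 1 == mode 0, the 1<->2 swap is unobservable
print(snapshots(softcap=1.0))   # softcap active: all three snapshots differ, so a test can pin the swap
```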

🟡 Minor

3. Naming inconsistency in QKMatMulOutputMode (attention_helper.h).
Three of the four values follow a kPost* "stage you exited" convention; kQK is the lone outlier and its name implies "Q times K" even though scaling has already been applied. Renaming to kPostGemm or kPreSoftCap would make the enum self-describing. (Acknowledged: the names mirror the spec text, so this may be intentional — flagging for consideration.)

4. snapshot_needs_pre_mask conflates two orthogonal conditions (attention.cc:~153). The expression gates on both out_qk != nullptr and the mode, but the name only describes the mode half. Splitting into a mode-only predicate (mode_requires_pre_mask_snapshot) and combining with the buffer-presence check at the use site would make the intent obvious. Not a correctness bug, but a trap for the next maintainer.

5. Silent perf regression on the softcap + mask path (attention.cc:504-526). With softcap enabled, the mask can no longer fold into the GEMM beta=1, so the path becomes GEMM(beta=0) → in-place softcap → separate AddInPlace over the full S×T scores per head. That is one extra read/write pass over the score matrix — measurable for long-context CPU attention. Not a regression to revert (correctness wins), but worth a benchmark and possibly a fused softcap+mask pass for follow-up.

6. New MLFloat16 AddInPlace fallback lacks focused coverage (attention.cc:111-115, 148-151). The scalar MLFloat16(scores[i].ToFloat() + addend[i].ToFloat()) path only fires off-ARM-NEON, and fp16 numerics are exactly where ordering changes are most likely to surface small parity drifts. The new Python ordering tests are fp32-only by design (per the comment at test_gqa.py:1931-1936), so the fp16 path on x86 has no targeted softcap+mask test.

🟢 Polish

7. kPostMaskBias snapshot is taken after the nonpad_kv_seqlen mask (attention.cc:540-543). Defensible — “post-mask/bias” arguably encompasses nonpad masking — but worth a one-line comment so the next reader doesn’t wonder.

8. "Pre-fix / post-fix" language in test_gqa.py class docstring and SKILL.md. Useful for the PR conversation, confusing once merged. Suggest replacing with the evergreen invariant: "any regression to wrong ordering would produce max |output| ≈ poison_value." Same for the SKILL.md parenthetical.

9. Line-number references in SKILL.md (test_gqa.py:1501, :1761) will drift on the next change to that file. Reference the test function names instead.

10. MlasEltwiseAdd<float>(addend, scores, scores, count) aliased pointers. Verified safe — MLAS loads both operands before storing, and the same pattern exists at attention_helper.h:37 and in dynamic_quantize_matmul.cc:316. Not an issue, just noting it was checked.

✅ Praise

  • The fold_mask_into_gemm / softcap_active factoring with the 8-line "why" comment is the right density for a correctness-sensitive invariant.
  • Poison-V tests are a much stronger oracle than “high-softcap ≈ identity” checks.
  • Pipeline diagram in SKILL.md (stage 0→1→2→3 with the integer mapping right below) is the clearest part of the diff.
  • Deep-memcpy snapshots before in-place mutation correctly avoid the alias/lifetime traps that this kind of staging refactor often introduces.
  • Leading the spec on a clear correctness bug (rather than waiting for ONNX v1.22 to make the as-shipped opset-23 behavior officially wrong) is the right product call.

Verification done during review

  • All four qk_matmul_output_mode × softcap × mask combinations traced by hand — flow is internally consistent.
  • Repo grep + 712-model parse confirmed no in-tree serialized models carry qk_matmul_output_mode (so CI won’t catch a regression on the chosen interpretation — finding #2 is the safety net).
  • CUDA path: confirmed rename-only plus existing host guard rejecting modes beyond kNone/kQK; no new CUDA snapshot behavior introduced.
  • Confirmed onnx/onnx#7867 and #7913 are post-release errata to opset 23 — the ordering/numbering this PR implements matches the corrected spec text, not the as-shipped opset-23 spec.

Reviewers: Opus 4.7 (lead), GPT-5.3-Codex (code), GPT-5.5 (critical), Sonnet 4.6 (readability).

@tianleiwu tianleiwu changed the title from "Fix CPU Attention softcap/attn_mask ordering to match ONNX v24 spec (#7867, #7913)" to "Fix CPU Attention softcap/attn_mask ordering to match ONNX v24 spec" on May 6, 2026
Contributor

@tianleiwu tianleiwu left a comment


Review Summary

The core fix is correct and well-motivated: the CPU EP now applies softcap before attn_mask/attn_bias, matching both the CUDA EP and the ONNX v23/v24 spec (onnx/onnx#7867 + #7913). The fold_mask_into_gemm refactoring is clean — it preserves FMA-fused numerics for the no-softcap path and only takes the explicit AddInPlace path when ordering matters.

Correctness notes (verified):

  • fold_mask_into_gemm guards are correct: disabled when softcap_active or snapshot_needs_pre_mask.
  • nonpad_kv_seqlen placement after softcap is correct — the overwrite with mask_filter_value ensures positions are effectively masked regardless of any prior softcap transformation (see the sketch after this list).
  • AddInPlace<MLFloat16> scalar fallback has no overflow risk (intermediate computation in fp32, inputs bounded by softcap range).
  • Snapshot ordering matches the pipeline stages exactly.
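
As a rough illustration of why the post-softcap placement of the sentinel matters (illustrative numbers, not taken from the kernel):

```python
import numpy as np

softcap = 0.1
raw = np.array([1.0, 2.0, 3.0])                      # scores; the last column is padding
sentinel = np.finfo(np.float32).min                  # mask_filter_value-style sentinel

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Current placement: softcap first, then overwrite the padded column with the sentinel.
after = softcap * np.tanh(raw / softcap)
after[2] = sentinel
print(softmax(after))      # padded column weight underflows to 0

# Old placement: sentinel first, then softcap. tanh squashes the sentinel to -softcap,
# so the padded column keeps roughly a quarter of the probability mass.
before = raw.copy()
before[2] = sentinel
before = softcap * np.tanh(before / softcap)
print(softmax(before))     # ~[0.36, 0.36, 0.29]
```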

Two test coverage gaps (endorsed from prior review threads) remain the main blocker for full confidence:

  1. Mode 1 (kPostSoftCap) is not differentiated from mode 0 (kQK) when softcap is active — the test uses softcap=NaN where they are observationally identical.
  2. No test combines nonpad_kv_seqlen with softcap > 0 to guard the reordered path.

These gaps are more important here than usual: because ORT intentionally leads the spec (applying #7913's numbering to both opset 23 and 24), the test suite is the de-facto contract for the chosen semantics.

Overall the fix is the right call. Addressing the two test gaps would let this merge with high confidence.

Contributor

@tianleiwu tianleiwu left a comment


Inline findings (posted as body due to GitHub API inline-comment issue on this PR):

onnxruntime/core/providers/cpu/llm/attention.cc line ~469 (fold_mask_into_gemm):
Good design: cleanly separates the legacy FMA path (no numeric regression) from the spec-correct path (softcap then explicit mask add). These conditions are exactly the cases that must avoid folding:

  • softcap_active: mask must come after softcap per spec
  • snapshot_needs_pre_mask: kQK/kPostSoftCap snapshots must not include the mask

Verified correct.

onnxruntime/core/providers/cpu/llm/attention.cc line ~107 (AddInPlace):
The MLAS fp16 limitation rationale is well-documented. If a future commit adds x86 fp16 MLAS support, this fallback should be gated behind a runtime check rather than being unconditional.

onnxruntime/test/providers/cpu/llm/attention_op_test.cc line ~1552:
Endorsing the earlier feedback: since this test uses softcap = NaN (disabled), mode 1 (kPostSoftCap) produces the same tensor as mode 0. A differentiating test with softcap > 0 would lock in the new semantics.

titaiwangms and others added 2 commits May 6, 2026 16:48
Tianlei BLOCKER #1: New mode-1+softcap differentiating test (C++ + Python).
  With softcap > 0 active, qk_matmul_output_mode=1 (post-onnx#7913 numbering =
  kPostSoftCap) snapshots softcap*tanh(scale*QK/softcap) with NO mask added.
  Without softcap, mode 1 aliases mode 0, so the swap is observationally
  indistinguishable — this test is what proves the 1<->2 swap actually
  changed semantics correctly.

Tianlei BLOCKER #2: New softcap+nonpad_kv_seqlen leakage test (C++ + Python).
  Exercises the latent fix where the nonpad sentinel is now applied AFTER
  softcap (per onnx#7867 ordering). Pre-fix: tanh squashed the sentinel,
  leaking poison V at padded positions through softmax.

Bot inline minors:
  - #3 (test_gqa.py): clarify fp16 docstring — CPU does support fp16; fp32 is
    the natural EP-native dtype for the canary.
  - #4 (attention_op_test.cc): regen comment now cites shared opset 23/24
    ordering and notes RunTest4D builds at opset 23.
  - #5 (attention_parameters.h): typo defintion -> definition.
  - #6 (attention.cc): replace 'guaranteed -inf' with precise wording citing
    mask_filter_value<T>() = numeric_limits::lowest() / MLFloat16::MinValue
    sentinel and the MLAS softmax finite-input requirement (attention.h).

R-2 #1 (attention_parameters.h): Spec-leading documentation block on the
  QKMatMulOutputMode enum noting that ORT now uses the post-onnx#7913
  numbering, while the bundled cmake/external/onnx (v1.21.0) still reflects
  the old numbering. ORT leads the spec change pending the next bundled-ONNX
  bump.

Plumbing: common.py attention_prompt_func gains an optional output_qk
  kwarg (default 0 / disabled). When > 0, returns a 4-tuple including the
  qk_matmul snapshot tensor; otherwise unchanged 3-tuple. No existing callers
  are affected.
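
A hypothetical sketch of that backward-compatible kwarg pattern (the signature, internals, and return layout here are illustrative assumptions, not the actual common.py code):

```python
import numpy as np

def attention_prompt_func(q, k, v, attn_mask=None, *, output_qk: int = 0):
    """Illustrative sketch only (not the actual common.py code). output_qk = 0 keeps
    the legacy 3-tuple return so existing callers are unaffected; output_qk > 0 picks
    a qk_matmul snapshot stage and switches to a 4-tuple carrying that snapshot."""
    if output_qk not in (0, 1, 2, 3):
        raise ValueError(f"output_qk must be one of 0, 1, 2, 3; got {output_qk}")

    scores = q @ k.swapaxes(-1, -2)                      # stand-in for the real score pipeline
    if attn_mask is not None:
        scores = scores + attn_mask
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    y, present_k, present_v = probs @ v, k, v
    qk_snapshot = scores                                 # would be the requested stage in real code

    if output_qk == 0:
        return y, present_k, present_v                   # unchanged legacy 3-tuple
    return y, present_k, present_v, qk_snapshot          # new 4-tuple with the snapshot
```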

Test results:
  - AttentionTest.* — 60/60 PASS (was 58, +2 new).
  - TestONNXAttentionCPUSoftcapMaskOrdering — 4/4 PASS (was 2, +2 new).
  - lintrunner clean across all 5 touched files.

Refs: lead-39245992/upstream-pr-status-recheck.md, pr1v2-review-{code,critical,readability,qa}.md

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…g bundled-ONNX update

Per architect 8b9842c3's recommendation (lead-39245992/pr1v2-onnx-fixture-handling.md):
mirror the existing lines 951-960 precedent ("Skipped until cmake/external/onnx
points to onnx 1.19 ... onnx/onnx#7074") and add a skip-with-cite block
for the attention fixtures regenerated upstream by onnx/onnx#7867 and
onnx/onnx#7913.

The bundled cmake/external/onnx is v1.21.0 (predates both PRs). Our impl
emits the corrected post-spec output, which disagrees with the still-old
fixtures shipped in v1.21.0. Skip until cmake/external/onnx is bumped to
>= v1.22, at which point the entries can be removed in a single cleanup
commit (greppable via 'v1.22 (includes onnx/onnx#7867').

20 entries added (10 base + 10 _expanded):
  - 4 softcap-related (cite onnx#7867)
  - 14 bias / qk_matmul_output_mode-related (cite onnx#7913)
  - 2 mask4d_padded_kv (cite onnx#7867 — same root cause; pre-existing
    QNN-only skip at line 1498 promoted to all providers)

Why not bump cmake/external/onnx instead: ONNX v1.22 has not shipped (latest
v1.21.0 = 2026-03-27; onnx/onnx#7867 merged 2026-04-30, onnx/onnx#7913 merged 2026-05-04). A
non-tagged SHA pin would cascade into opset registrations, fusion passes,
function-body decompositions, possibly opset-25 ops, and 80+ unrelated
fixture regenerations from onnx/onnx#7867 alone — out of scope for a CPU behavioral
fix. The bump deserves its own dedicated PR.

Verification (./build/Linux/Debug/onnx_test_runner -e cpu -j 1
cmake/external/onnx/onnx/backend/test/data/node):
  - Pre-patch attention failures: 11 (10 from new-spec + 1 mask4d_padded_kv)
  - Post-patch attention failures: 0
  - Total cases: 1588 -> 1568 (20 skipped, matching added entries)
  - Only remaining failure: convinteger_with_padding (pre-existing, unrelated)
  - AttentionTest.* still 60/60 PASS
  - lintrunner clean

Refs: lead-39245992/pr1v2-onnx-fixture-handling.md (architect 8b9842c3)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms
Contributor Author

Thanks @tianleiwu for the explicit blockers and to @copilot-pull-request-reviewer for the inline catches. This round addresses both blockers + every minor:

BLOCKER 1 — mode-1+softcap differentiating coverage. Without softcap > 0, the qk_matmul_output_mode 1↔2 enum swap is observationally equivalent to a rename — no scenario distinguishes the two integer mappings. Two new tests now do: C++ Attention_QkMatmulOutputMode_PostSoftCap_WithSoftcap_CPU and Python test_cpu_attention_qk_matmul_output_mode_post_softcap_with_softcap_fp32. They construct a case where mode 1 under post-#7913 numbering yields softcap * tanh(scale*QK/softcap) and under pre-#7913 numbering would yield [raw_qk, lowest()]. Both force CPU EP; snapshot tolerance 1e-5.

BLOCKER 2 — softcap + nonpad_kv_seqlen leakage. The latent ordering bug (sentinel pre-softcap → tanh squashes it → ~25% leakage to padded positions) is now pinned by C++ Attention_NonPadKVSeqLen_WithSoftcap_NoLeakage_CPU and Python test_cpu_attention_softcap_nonpad_kv_seqlen_no_leakage_prompt_fp32, using poison V (=1000) at padded positions; bounded output (≈ 1.0) confirms the post-fix ordering.

Spec-leading documentation (R-2 #1 / your spec-versioning concern). A 12-line IMPORTANT block now opens the QKMatMulOutputMode enum in attention_parameters.h stating ORT intentionally LEADS the bundled ONNX submodule (v1.21.0, predates onnx/onnx#7913); the v1.22 bump will retroactively make ORT spec-correct with no behavior change.

ONNX backend node tests. 20 fixtures in cmake/external/onnx encode the pre-#7867/#7913 ordering and disagree with our impl. Added them to TestCase.cc::GetBrokenTests with the unblock criterion cmake/external/onnx >= v1.22; pattern matches the existing cast_*_INT4 precedent (bare names, exact strings). Will be removed in the same PR that bumps the submodule.

Bot inlines all applied: typo defintion → definition, nonpad_kv_seqlen masking comment now correctly cites mask_filter_value<T>() (with reference to core/mlas/lib/eltwise.cpp for the MLAS finite-input rationale) instead of "guaranteed -inf", RunTest4D regen comment now notes the opset 23/24 shared-ordering, and the CPU softcap-mask docstring no longer claims fp16 unfused-attention is unsupported.

Verified locally: 60/60 AttentionTest.* PASS (was 58/58 + 2 new). Pre-patch 11 attention node-test failures → post-patch 0 (with the skip-list active). Zero CUDA production code changes — CPU-only fix. lintrunner clean.

Pushed as 2 additive commits, no force-push, so the review history stays intact:

  • 87c7e89 — tests + bot minors + spec-leading doc block
  • 0cd61f9 — ONNX backend node-test skip-list

…runner filter

Hot-fix completes the skip-list coverage started in 0cd61f9.

The previous commit added entries to onnxruntime/test/onnx/TestCase.cc::GetBrokenTests,
which is consumed by the C++ onnx_test_runner binary only. The Python
onnxruntime/test/python/onnx_backend_test_series.py wrapper around
onnx.backend.test.runner.Runner uses a SEPARATE filter file:
onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
('current_failing_tests' array). Both runners need their own exclusion.

Adds 7 CPU-suffixed entries (immediately after the existing CUDA counterparts
at lines 46-54) covering the same fixtures regenerated upstream by
onnx/onnx#7867 (softcap-then-mask ordering) and onnx/onnx#7913
(qk_matmul_output_mode 1<->2 numbering):
  - test_attention_3d_with_past_and_present_qk_matmul_bias_cpu
  - test_attention_3d_with_past_and_present_qk_matmul_softcap_cpu
  - test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cpu
  - test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cpu
  - test_attention_4d_with_past_and_present_qk_matmul_bias_cpu
  - test_attention_4d_with_qk_matmul_bias_cpu
  - test_attention_4d_with_qk_matmul_softcap_cpu

Removable in a single cleanup commit (greppable: 'pre-onnx#7867 fixture' /
'pre-onnx#7913 fixture') when cmake/external/onnx is bumped to v1.22+ — same
unblock criterion as 0cd61f9.

Verification:
  - JSONC parses cleanly (300 entries in current_failing_tests).
  - lintrunner clean.
  - 0 production code touched. Pure additive filter update.

Refs: lead-39245992/pr1v2-ci-failure-triage.md (architect 8b9842c3)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms and others added 2 commits May 6, 2026 18:30
…rosoft#28371

Migrate the output_qk plumbing in common.py (now `int | None`-typed with strict {0,1,2,3} validation) and three CUDA GQA tests from PR microsoft#28371 to this PR, where they semantically belong:

- The output_qk parameter numbering follows the post-onnx#7913 enum swap (kQK/kPostSoftCap/kPostMaskBias/kPostSoftMax) introduced in this PR.
- The masked fp32 softcap ordering tests pin the post-onnx#7867 'scale -> softcap -> +mask -> softmax' spec that this PR's CPU implementation enforces.

Migrated to test_gqa.py:
- TestONNXAttentionGQAOutputQK (1 test: GQA + raw QK output, fp16, unfused path)
- TestONNXAttentionGQASoftcapFloat32MaskOrdering (helper + 2 tests: symmetric and asymmetric-V-head poison-V tests on the unfused fp32 GQA path)

The unmasked fp32 GQA softcap baseline tests (TestONNXAttentionGQASoftcapFloat32) remain on microsoft#28371 — they are pure CUDA-side softcap coverage that does not depend on the spec ordering or the enum swap.

This migration also resolves the textual conflict between microsoft#28371 and microsoft#28379 in common.py and test_gqa.py, since the output_qk API can now only land once (here).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…SKILL.md cleanup)

User requested folding PR microsoft#28371 entirely into this PR to eliminate the merge-ordering hazard between the two PRs and present a coherent fix+tests+docs reviewer story.

Following the earlier d613966 migration of the output_qk plumbing and masked fp32 softcap ordering tests, this commit migrates the remaining microsoft#28371 content:

test_gqa.py:
- TestONNXAttentionGQAAsymmetricHeadSize (REG, 2 tests) — guards the silent-broken-output regression on GQA + asymmetric Q/V head sizes that was fixed by PR microsoft#28358 (microsoft#28357). Pins the post-fix unfused-path behaviour on fp16 + bf16.
- TestONNXAttentionGQASoftcapFloat32 (1e baseline, 2 tests) — pins fp32 + softcap + GQA on the unfused path (MEA excludes is_gqa && fp32). Sibling to TestONNXAttentionGQASoftcapFloat32MaskOrdering already in this PR.

SKILL.md (cuda-attention-kernel-patterns):
- MEA eligibility paragraph: clarify that head_size%8 is enforced by has_memory_efficient_attention, and that head_size == v_head_size is required for GQA (LaunchUngroup) in addition to decode (LaunchConcatNewToPastKV). Reflects the post-microsoft#28358 host-side gate cleanup.

After this commit, PR microsoft#28371 is fully superseded; the lead will close it.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "Fix CPU Attention softcap/attn_mask ordering to match ONNX v24 spec" to "Fix CPU Attention softcap/attn_mask ordering (onnx#7867, #7913) + consolidate CUDA spec coverage tests from #28371" on May 6, 2026
@titaiwangms titaiwangms requested a review from Copilot May 6, 2026 19:31
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

@titaiwangms
Contributor Author

Review (4-agent fan-out: code-reviewer + critical-reviewer + readability-reviewer + me)

🔴 Critical

None.

🟡 Major (worth addressing before merge)

  1. Missing test: kPostMaskBias snapshot + nonpad_kv_seqlen. The new code correctly takes the mode-2 snapshot AFTER nonpad masking is applied, but no test exercises that combination. A future regression that moves the snapshot earlier would slip past the Y-only assertions. Add one CPU test: softcap > 0 + nonpad_kv_seqlen < kv_seq_len + qk_matmul_output_mode=2 + assert padded columns hold the mask sentinel.

  2. Spec-leading essay needs an explicit TODO(onnx-v1.22) cleanup hook. The 25-line "IMPORTANT" comment in attention_parameters.h and the 18-entry skip block in TestCase.cc (and the matching 7-entry block in onnx_backend_test_series_filters.jsonc) explain why the divergence exists but give a future maintainer no actionable trigger. When cmake/external/onnx is bumped, all three blocks become silent stale debt. Add at the top of each:

    // TODO(onnx-v1.22): When cmake/external/onnx is bumped to a release that
    // includes onnx/onnx#7867 + #7913, (1) delete this block, (2) remove the
    // matching skip-list entries in TestCase.cc and
    // onnx_backend_test_series_filters.jsonc.
    

    Lowest-cost way to make the deliberate divergence retrievable when it matters.

  3. Confusing wording: "FMA-fused numerics that pre-spec-fix tests were calibrated against" (attention.cc:452-453). Implies FMA is mandatory and that tolerances would break on a non-FMA build. Suggested:

    "Keeps the numerical path identical to the pre-fix implementation on the no-softcap case, so pre-existing test expected values remain valid."

  4. Optional output_qk subtly perturbs primary Y numerics. fold_mask_into_gemm is disabled when kQK or kPostSoftCap is requested, so the same (softcap=0, mask present) input runs through beta=0 + explicit-add instead of beta=1 FMA. Mathematically equivalent but not bitwise — adds an "observer effect" to a diagnostic output. Either:

    • acknowledge with a short comment near snapshot_needs_pre_mask, or
    • add a regression test confirming Y matches with/without output_qk on the masked, softcap-off case.

🔵 Minor

  1. assert output_qk in (0,1,2,3) (common.py:170) is removed under python -O. Replace with if output_qk not in (0,1,2,3): raise ValueError(...).
  2. snapshot_needs_pre_mask (attention.cc:466) conflates two distinct semantics — kQK blocks the fold because it's pre-softcap (hence also pre-mask), while kPostSoftCap blocks because it's only pre-mask. Rename to snapshot_is_before_mask_stage or add an inline note clarifying both branches.
  3. SKILL.md "(post-fix; pre-fix)" reads like postfix/prefix operator notation. Reword to "before / after PR #28379".
  4. Stale line-number refs in test_gqa.py:1934-1937 ("at line 1501, … at line 1761") will go stale on any insert. Use Class.method names instead.
  5. Cross-EP softcap validation drift: CUDA enforces softcap >= 0, CPU silently treats negative as disabled (softcap_active = parameters.softcap > 0.0f). PR touches both files; trivially align by moving the check into the shared Attention<T> constructor or the helper.
  6. SKILL.md inline mode mapping (0 = …, 1 = …) should be a bullet list to match the rest of section 4's style — faster to skim.

↪️ Follow-up (out of scope, surfacing for tracking only)

  • ORT attention_helper.h rejects shorter attn_mask KV dim with nonpad_kv_seqlen, but ONNX v24 spec allows it. Real perf hit for users with large external KV caches.
  • nonpad_kv_seqlen == 0 produces inconsistent output across EPs (CPU = mean(V), CUDA MEA = zeros). Pre-existing TODO; pick one and enforce.
  • CUDA validates nonpad_kv_seqlen values via device asserts rather than host-side fail-closed validation — resilience concern under adversarial input.

✅ What's right

  • Fix logic is correct. fold_mask_into_gemm preserves the FMA-fused path on the no-softcap, no-snapshot case (where pre-existing tests live) and switches to explicit-add exactly when spec correctness demands it.
  • Enum rename is clean across all consumers — no stale kQKMask/kQKSoftCap/kQKSoftMax references in ORT production code.
  • CUDA attention.cc diff is genuinely cosmetic-only — no behavior change, as claimed.
  • AddInPlace<MLFloat16> doc comment (attention.cc:97-105) is exemplary: names the failing function, the build condition, file:line for context, and the runtime consequence.
  • ASCII pipeline diagram in SKILL.md §4 is the clearest single artifact in the PR.
  • Skip-list rationale comments correctly use "expected" and cite both upstream PRs, preventing future triage as bugs.
  • Spec-leading the kernel ahead of the bundled ONNX submodule is an intentional design choice (confirmed) — the v1.22 bump becomes a no-op.

Recommendation

Approve once Major #1 and Major #2 are addressed. Major #3 and #4 are wording/test-hygiene improvements. Minors are take-or-leave.



Development

Successfully merging this pull request may close these issues.

Softcap in Attention op
