Skip to content

Fix softcap/attn_mask ordering in CUDA Attention to match ONNX v24 spec and CPU EP#28370

Closed
titaiwangms wants to merge 1 commit intomicrosoft:mainfrom
titaiwangms:fix-softcap-attn-mask-ordering
Closed

Fix softcap/attn_mask ordering in CUDA Attention to match ONNX v24 spec and CPU EP#28370
titaiwangms wants to merge 1 commit intomicrosoft:mainfrom
titaiwangms:fix-softcap-attn-mask-ordering

Conversation

@titaiwangms
Copy link
Copy Markdown
Contributor

Branch: fix-softcap-attn-mask-ordering (off microsoft/onnxruntime main)
Commit: da4a0f6d30

Motivation

The ONNX opset 24 Attention operator specifies a fixed score-compute pipeline:

QKAttn  = scale * (Q @ K^T)
QKBias  = QKAttn + attn_bias        ← bias added FIRST
QKCap   = softcap * tanh(QKBias / softcap)   ← then softcap saturates
QKProb  = softmax(QKCap)
out     = QKProb @ V

(See cmake/external/onnx/onnx/defs/nn/defs.cc lines 3322–3341 for the ASCII
pipeline and lines 3657–3675 for the reference function graph.)

The CPU EP at core/providers/cpu/llm/attention.cc already implements this
order. Both CUDA paths had the order inverted: softcap was applied first,
and the bias was added on top of the [-cap, +cap] saturated values.

EP / kernel Pre-PR ordering Spec ordering
CPU core/providers/cpu/llm/attention.cc bias → softcap ✅ bias → softcap
CUDA unfused contrib_ops/cuda/bert/unfused_attention.cu softcap → bias ❌ bias → softcap
CUDA MEA contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h softcap → bias ❌ bias → softcap

Why this was not caught earlier

For the most common bias case — a hard causal/padding mask with
bias = std::numeric_limits<float>::lowest() — both orderings are
observationally equivalent: masked positions still drive the softmax weight
to ~0 either way. The bug only manifests when softcap > 0 is combined with
float-valued additive bias of moderate magnitude (for example,
ALiBi-style position biases as used in some Gemma-family models). In that
case the bias values escape the [-softcap, +softcap] saturation range,
producing a different softmax distribution than the spec.

Spec citation

cmake/external/onnx/onnx/defs/nn/defs.cc (bundled ONNX submodule, the
authoritative reference for the operator):

  • Lines 3322–3341 — schema doc for opset 24 Attention with the explicit
    Add(QKAttn, AttnBias) → softcap → softmax pipeline.
  • Lines 3657–3675 — reference function add_attn_bias/apply_softcap
    nodes wired in the same order.

Before / after numerical example

Setup (matches the new Attention4DSoftcapFloatBias_Unfused_FP32 test):

  • q = k = v = [1.0, 1.0, ..., 1.0] (single batch, 1 head, q_seq=1, kv_seq=2,
    head_size=8)
  • All-ones bias × -10.0 (i.e. attn_bias = -10 for every position)
  • softcap = 5.0

CPU / spec ordering (this PR, fp32):

score = scale * Q@K^T = (1/sqrt(8)) * 8 = 2.8284
+bias                 = 2.8284 + (-10)  = -7.1716
softcap(5, tanh)      = 5 * tanh(-7.1716 / 5) ≈ -4.4070
softmax(2 identical)  = [0.5, 0.5]
out @ V               = [1.0320, 1.0320, ...]   ← spec

Pre-PR CUDA ordering (softcap-then-bias):

score                 = 2.8284
softcap first         = 5 * tanh(2.8284/5) ≈ 2.5132
+bias                 = 2.5132 + (-10)     = -7.4868
softmax(2 identical)  = [0.5, 0.5]
out @ V               ≈ [1.0002, 1.0002, ...]   ← buggy

The output diverges by ~0.032 per element — well outside any sensible
fp32 tolerance and visible in fp16 too. With per-position varying bias
the softmax distribution itself drifts, not just the post-scaling.

Files changed

File Change
onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu Swap softcap and bias-add blocks in all three passes (max / sum / normalize) of UnfusedSoftmaxKernel. Update doc comment.
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h Swap softcap and bias-add blocks inside the score-compute lambda so bias lands in the accumulator fragment before the 1/cap, fast_tanh, *cap triple.
onnxruntime/core/providers/cuda/llm/attention.cc Update the inline comment in RunMemoryEfficientAttention describing the (now-fixed) ordering. No predicate / dispatch changeattention.cc:1383 is PR #2 territory.
.agents/skills/cuda-attention-kernel-patterns/SKILL.md §4 rewritten to assert the spec-correct ordering and cite defs.cc:3322-3341 / 3657-3675; previous claim (which referenced onnx/onnx#7865) kept as a "historical note" so future agents don't re-introduce it. §1 MEA-eligibility bullet amended to list LaunchUngroup's equal-head-size constraint alongside LaunchConcatNewToPastKV (closing the doc gap noted in PR #28358). §9 file-purpose row corrected.
onnxruntime/test/providers/cpu/llm/attention_op_test.cc Two new regression tests (see Test plan).

Total: 5 files, +175 / −35.

Test plan

New regression tests

Both are placed in test/providers/cpu/llm/attention_op_test.cc and follow
the existing RunTest4D / EnvVarMap patterns in that file.

  1. Attention4DSoftcapFloatBias_Unfused_FP32 — fp32, head_size=8,
    q_seq=1, kv_seq=2. The bias stride (kv_seq=2) is not a multiple of 4,
    so MEA's bias-alignment check rejects this shape and CUDA falls through
    to the unfused kernel. RunTest4D iterates over CPU and CUDA EPs, so
    the test asserts CPU↔CUDA parity at the spec value (Y ≈ 1.0320).
    Pre-PR CUDA value would be ~1.0002 — delta 0.032, vs the test's fp32
    tolerance of 3e-5.

  2. Attention4DSoftcapFloatBias_MEA_FP16 — fp16, head_size=8, q_seq=1,
    kv_seq=4 (bias stride aligned to 4). Forces MEA by setting
    kDisableFlashAttention="1" via EnvVarMap. Asserts the same spec
    output (Y ≈ 1.5200) against the buggy ~1.5001. Delta 0.0199, vs the
    test's fp16 tolerance of 5e-3.

Both expected outputs were hand-computed (see commit message and the
"before/after" section above).

Existing tests

The change preserves behaviour for hard masks (bias = lowest()), so
existing softcap+causal-mask coverage is unaffected. The pre-existing
suite at attention_op_test.cc exercises softcap with causal/padding
masks — those will continue to pass.

Local build / test status

  • Both modified CUDA kernels compile cleanly:
    • unfused_attention.cu.o — built ✅
    • cutlass_fmha/fmha_sm{50,70,75,80}.cu.o — built ✅
    • The new test .cc.o — built ✅
  • lintrunner -a on all five files: no issues.
  • Verified locally on H100 (sm_90a): all AttentionTest.* cases pass,
    including the two new regression tests
    (Attention4DSoftcapFloatBias_Unfused_FP32 and
    Attention4DSoftcapFloatBias_MEA_FP16).

Risk / blast radius

  • Risk: medium. This is a silent numerical change for one specific
    combination — softcap > 0 and float-valued additive attn_bias.
    Any user in that regime will now see different (spec-correct) outputs.
  • Hard-mask users (the common case): unaffected. Behaviour is
    observationally equivalent because the bias drives softmax to ~0
    regardless of order.
  • CPU EP: unchanged (already correct).
  • Other EPs (DML, CoreML, OpenVINO, WebGPU, QNN, JS, ...): out of scope
    — this PR fixes only the CUDA-EP kernels under
    contrib_ops/cuda/bert/. If any of those EPs has its own softcap
    kernel, it should be audited under a separate issue.
  • Performance: negligible. The change is a swap of two short blocks
    inside an already-iterated score-compute loop; instruction count is
    identical, no new memory traffic.
  • Skill doc edits are documentation-only.

Out of scope (handled by PR #2)

  • attention.cc:1383 — MEA-eligibility predicate / head_size % 4 gate.
  • GQA + output_qk validation, mask right-padding, the
    head_size != v_head_size GQA regression test, and the broader
    head_size ∈ {6, 10, 12, 16, 24} sweep.

Follow-ups (separate issues, intentionally not in this PR)

These were noted during review and are recorded here for traceability.
None block this PR.

  • CUTLASS kSupportsBias=false + softcap > 0
    (kernel_forward.h ~line 857-860): when no bias is supplied,
    p.scale is not pre-applied to the accumulator before the score
    loop, yet the softcap operates on the unscaled accum. The block
    fixed in this PR was always inside the attn_bias_ptr != nullptr
    branch, so the bug doesn't surface today. Filed as a follow-up
    because it's an independent latent issue that needs its own
    reproducer.
  • Cross-EP softcap-ordering audit: DML, CoreML, OpenVINO, WebGPU,
    QNN, and the JS EPs all implement Attention separately and may have
    the same bias/softcap inversion. Each needs its own audit + fix +
    test.
  • Additional regression coverage: ALiBi-style monotonically
    varying bias, bf16 variant of the existing tests, and a GQA softcap
    variant. The two tests added here are the minimum needed to lock
    the spec contract; broader coverage belongs to a follow-up.

…spec

The ONNX opset 24 Attention reference function in cmake/external/onnx/onnx/defs/nn/defs.cc (lines 3322-3341 ASCII pipeline, 3657-3675 reference graph) specifies the score-compute order as scale*Q@K^T -> +bias -> softcap -> softmax. The CPU EP at core/providers/cpu/llm/attention.cc already matches this. Both CUDA kernels (the unified unfused softmax kernel and the CUTLASS MEA fused score-compute loop) had the order inverted: they applied softcap before adding bias.

For hard masks (bias = lowest()) the two orderings are observationally equivalent because masked positions still drive softmax weight to ~0. For float biases of moderate magnitude (e.g. ALiBi-style position biases combined with softcap) the orderings differ materially: applying bias after softcap lets bias values escape the [-cap, +cap] saturation range.

Changes:
* unfused_attention.cu: swap softcap and bias-add in all three passes of UnfusedSoftmaxKernel; update doc comment.
* cutlass_fmha/kernel_forward.h: swap softcap and bias-add in score-compute so bias lands in accum before the 1/cap, fast_tanh, *cap triple.
* core/providers/cuda/llm/attention.cc: update inline comment in RunMemoryEfficientAttention. Eligibility predicate untouched (PR microsoft#2).
* .agents/skills/cuda-attention-kernel-patterns/SKILL.md: rewrite §4 to assert spec-correct ordering with citation; amend §1 MEA-eligibility bullet to list LaunchUngroup's equal-head-size constraint alongside LaunchConcatNewToPastKV.
* test/providers/cpu/llm/attention_op_test.cc: add Attention4DSoftcapFloatBias_Unfused_FP32 (CPU+CUDA parity) and Attention4DSoftcapFloatBias_MEA_FP16 (CUDA MEA, Flash forced off). Both use bias=-10, softcap=5 so spec vs buggy outputs differ ~0.02-0.03 per dim, well outside tolerance.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms requested a review from Copilot May 5, 2026 20:05
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Fixes CUDA Attention’s softcap and additive-bias application order to match the ONNX opset 24 Attention reference pipeline and CPU EP behavior, and adds regression coverage to prevent reintroductions.

Changes:

  • Reorders bias-add and softcap in CUDA unfused softmax (unfused_attention.cu) and CUTLASS MEA score path (kernel_forward.h) to follow scale → +bias → softcap → softmax.
  • Updates CUDA provider documentation/comments and the CUDA attention skill doc to reflect the opset 24 spec ordering.
  • Adds two regression tests targeting float-valued additive bias + softcap for both unfused CUDA and MEA CUDA paths.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
onnxruntime/test/providers/cpu/llm/attention_op_test.cc Adds two CUDA-focused regression tests for bias/softcap ordering in unfused and MEA paths.
onnxruntime/core/providers/cuda/llm/attention.cc Updates MEA path comment to document spec-correct ordering and cite ONNX reference.
onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu Swaps bias/softcap application order across all three passes in UnfusedSoftmaxKernel.
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h Moves bias-add ahead of softcap in the accumulator fragment before tanh saturation.
.agents/skills/cuda-attention-kernel-patterns/SKILL.md Updates internal guidance to assert opset 24 ordering and clarifies MEA constraints and file descriptions.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +3036 to +3044
// disable_dml=true: DML EP does not implement opset-24 Attention with float
// additive bias + softcap; excluding it keeps this test focused on
// CPU<->CUDA parity for the spec ordering fix.
RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
q, k, v, attn_mask, std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), 5.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
false, false, true // disable_cpu, disable_cuda, disable_dml
);
// produce the spec-correct output.
//
// Inputs are designed so the two orderings produce materially different Y:
// spec: Y[i] ≈ 1.0320 (softcap saturates the masked combined score to ≈ -5)
@titaiwangms titaiwangms closed this May 5, 2026
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request May 5, 2026
…microsoft#7913)

The CPU ONNX Attention op was applying mask/bias BEFORE softcap, violating
the v23/24 spec post onnx/onnx#7867 + onnx/onnx#7913. With finite softcap
active, -inf mask values were squashed by tanh into bounded -c, causing
probability leakage onto masked positions and arithmetic mixing of poison
V-values into the output (the canary tests in commit d76e45c demonstrate
the leak: max |y| > 100 with poison V at a -inf-masked position).

This is the inverse fix to PR microsoft#28370: CUDA was already correct per spec
(it splits the softcap snapshot before the mask add); CPU is the violator.

Changes:
- ComputeAttentionProbs<T>: refactor the per-head loop. When softcap is
  active, run GEMM with beta=0 (raw scale*Q*K^T), apply softcap inplace,
  then add mask explicitly via new AddInPlace<T> helper. When softcap is
  disabled, preserve the original beta=1 fold path (mask preloaded into C,
  FMA-accumulated) so pre-existing test calibrations remain numerically
  identical. Snapshot offsets for kQK / kPostSoftCap / kPostMaskBias /
  kPostSoftMax now follow the spec-correct stage order.
- AddInPlace<T>: new helper. Uses MlasEltwiseAdd<float> for fp32; FP16
  takes a scalar-fallback path because MlasEltwiseAdd<MLAS_FP16>'s
  dispatch->Add_Fp16 is not populated on all builds.
- attention_parameters.h: rename QKMatMulOutputMode enumerators per microsoft#7913
  numbering swap. Old: kQK=0, kQKMask=1, kQKSoftCap=2, kQKSoftMax=3.
  New:  kQK=0, kPostSoftCap=1, kPostMaskBias=2, kPostSoftMax=3.
- cuda/llm/attention.cc: update enum-tag references in comments and the
  qk_matmul_output_mode ENFORCE message (no logic change; CUDA was already
  spec-correct).
- attention_op_test.cc Attention4DWithPastAndPresentQkMatmul: regenerated
  expected y[] and qk_matmul[] arrays for the softcap=1.0 sub-call.
  Pre-fix arrays were calibrated against the OLD buggy ordering
  (softmax(softcap(raw+mask))) and necessarily change once the spec is
  honored (softmax(softcap(raw)+mask)). Five other sub-calls in the same
  test (modes -1/0/1/2/3, no softcap) are unchanged because mask add
  trivially commutes with the no-op softcap.

With this change, the canary tests added in d76e45c transition from
FAIL->PASS on CPU EP while CUDA EP guards continue to pass:
- AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CPU
- AttentionTest.Attention_Unfused_Softcap_NegInfMask_PoisonV_CUDA
- TestONNXAttentionCPUSoftcapMaskOrdering.test_cpu_attention_softcap_*

Verification: 58/58 AttentionTest.* + 19/19 GQATest.* + 217/218 Python
transformers/test_onnx_attention (the one failure is an unrelated
pre-existing tensorscatter validation test, unmodified by this PR).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants