[MPS] Fix SDPA output shape when value head dim differs #176843

Closed

hvaara wants to merge 1 commit into pytorch:main from hvaara:mps-sdpa-ev-shape-fix
Conversation

hvaara (Contributor) commented Mar 8, 2026

This fixes the MPS SDPA output shape for cases where `value.size(-1) != query.size(-1)`, so the output now follows `(..., L, Ev)` as expected. I also added guards in the Metal kernel paths that assume equal Q/K/V head dims.

Also updated the meta shape inference for the `sdpa_general_mps` path, which seems to have been left out initially.

Added regression coverage in `test/test_transformers.py` for the shape semantics, and a similar test in `test/test_mps.py` that also checks numerical parity with CPU.

Fixes #176767
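
A minimal repro of the fixed shape semantics (a sketch with illustrative sizes, assuming an MPS device is available). Per the SDPA contract, the output shape follows the value head dim, `(..., L, Ev)`; before this fix, the MPS path returned a wrongly shaped output here:

```python
import torch
import torch.nn.functional as F

# Q/K share head dim E; V has a different head dim Ev (illustrative sizes).
B, H, L, S, E, Ev = 2, 4, 8, 8, 64, 32
q = torch.randn(B, H, L, E, device="mps")
k = torch.randn(B, H, S, E, device="mps")
v = torch.randn(B, H, S, Ev, device="mps")

out = F.scaled_dot_product_attention(q, k, v)
# The documented output shape follows the value head dim: (..., L, Ev).
assert out.shape == (B, H, L, Ev)
```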

hvaara requested a review from malfet as a code owner March 8, 2026 20:22
pytorch-bot (Bot) commented Mar 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176843

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cf60249 with merge base 7643509:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

malfet (Contributor) commented Mar 9, 2026

@pytorchbot merge -f "Lint + MPS is green"

pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging

Check the merge workflow status here.

hvaara deleted the mps-sdpa-ev-shape-fix branch March 9, 2026 21:55
pytorchmergebot pushed a commit that referenced this pull request Mar 17, 2026
Fix proposed by @mergennachin in #177603. The issue was introduced in #176843.

Remove data-dependent branching in the MPS SDPA meta kernel so export supports dynamic sequence lengths.

Update meta-dispatch test to compare only the first output and add an export regression test.

@angelayi, you wrote the original meta registration and tests in #159695. Does this LGTY?

Fixes #177603
Pull Request resolved: #177620
Approved by: https://github.com/malfet
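
For context, a minimal sketch of the constraint at play (the helper below is hypothetical, not the actual PyTorch meta registration): under torch.export, a meta kernel computes output shapes symbolically, so branching on a dimension like seq_len guards on a symbolic value and breaks dynamic shapes. The SDPA shape rule needs no branch at all:

```python
import torch

# Hypothetical meta shape helper for SDPA, shown for illustration only.
# The output shape (..., L, Ev) is derived purely from the input sizes;
# an `if seq_len <= 8:` style branch would compare against a symbolic
# dimension and fail under torch.export with dynamic sequence lengths.
def sdpa_meta_shape(query: torch.Tensor, value: torch.Tensor) -> torch.Size:
    return torch.Size((*query.shape[:-1], value.shape[-1]))
```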
mergennachin added a commit that referenced this pull request Mar 17, 2026
Summary

Add float mask test coverage to `test_sdpa_export_dynamic_seq_len`, complementing the fix in #177620.

The original regression (#177603, introduced by #176843) was triggered in practice by float attention masks: Metal SDPA requires the mask dtype to match the Q/K/V dtype, so real models use float masks, not bool. The existing test only covered bool masks. This also verifies behavior across the <= 8 / > 8 seq_len boundary that was the branching condition in the buggy meta kernel.

Test plan

- `python test/test_mps.py TestSDPA.test_sdpa_export_dynamic_seq_len`
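
A sketch of what this coverage exercises (shapes and sizes are illustrative, not the exact test): an additive float mask whose dtype matches Q/K/V, run on both sides of the seq_len boundary described above:

```python
import torch
import torch.nn.functional as F

# Float additive mask (0.0 = attend), matching the Q/K/V dtype as Metal
# SDPA requires. The seq_len values straddle the <= 8 / > 8 boundary that
# the buggy meta kernel branched on.
for seq_len in (8, 9):
    q = torch.randn(1, 2, seq_len, 16)
    k = torch.randn(1, 2, seq_len, 16)
    v = torch.randn(1, 2, seq_len, 16)
    mask = torch.zeros(1, 1, seq_len, seq_len, dtype=q.dtype)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    assert out.shape == (1, 2, seq_len, 16)
```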
Jah-yee pushed a commit to Jah-yee/transformers-jy that referenced this pull request Apr 16, 2026

Workaround for PyTorch < 2.12 bug (pytorch/pytorch#176767, pytorch/pytorch#176843)
where scaled_dot_product_attention on MPS produces incorrect output when
value head dim != query head dim.

DeepSeek models (MLA) are affected, as they use different Q/K and V head dims. The fix pads v to match q's head dim before SDPA, then truncates the output back to the original v size.

Fixes huggingface#44554
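
A minimal sketch of the described workaround (the wrapper name is hypothetical, not the actual transformers code; it assumes the value head dim is no larger than the query head dim, as in the DeepSeek case):

```python
import torch
import torch.nn.functional as F

def sdpa_with_padded_value(q, k, v, **kwargs):
    # Zero-pad v's head dim up to q's, run SDPA, then truncate the output
    # back to v's original head dim. Since attention is softmax(QK^T) @ V,
    # zero columns in V yield zero columns in the output, so truncating
    # recovers the exact result.
    ev = v.size(-1)
    if ev < q.size(-1):
        v = F.pad(v, (0, q.size(-1) - ev))
    out = F.scaled_dot_product_attention(q, k, v, **kwargs)
    return out[..., :ev]
```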

Labels

ciflow/mps (run MPS tests, subset of trunk), Merged, open source, release notes: mps, topic: bug fixes


Development

Successfully merging this pull request may close these issues.

MPS: scaled_dot_product_attention returns wrong output shape when value dim != query/key dim

4 participants