Skip to content

Fix MPS SDPA output shape when value head dim differs from query head dim#45467

Closed
Jah-yee wants to merge 1 commit intohuggingface:mainfrom
Jah-yee:fix/mps-sdpa-head-dim-workaround
Closed

Fix MPS SDPA output shape when value head dim differs from query head dim#45467
Jah-yee wants to merge 1 commit intohuggingface:mainfrom
Jah-yee:fix/mps-sdpa-head-dim-workaround

Conversation

@Jah-yee
Copy link
Copy Markdown

@Jah-yee Jah-yee commented Apr 16, 2026

Good day,

Problem

On Apple Silicon (MPS backend), torch.nn.functional.scaled_dot_product_attention produces incorrect output when the value tensor's head dimension differs from the query tensor's head dimension. This affects DeepSeek models (MQA) which have different qk and v head dims, and any model with multi-query attention where the value head dim != query head dim.

Root cause: Upstream PyTorch bug (pytorch/pytorch#176767) that was fixed in PyTorch 2.12 (pytorch/pytorch#176843), but the fix is not yet widely available.

Solution

This adds a workaround in sdpa_attention_forward that:

  1. Detects when running on MPS with mismatched q/v head dims
  2. Pads the value tensor to match query's head dimension before the SDPA call
  3. Truncates the output back to the original value head dimension after SDPA

This is the approach suggested and agreed upon by maintainers in the linked issue discussion.

Changes

  • src/transformers/integrations/sdpa_attention.py: Added MPS head-dim workaround in sdpa_attention_forward
  • Also imported is_torch_mps_available for runtime MPS detection

Testing

The original reproducer from the issue:

import torch
import torch.nn.functional as F

q = torch.rand(1, 1, 8, 4, device="mps")
k = torch.rand(1, 1, 8, 4, device="mps")
v = torch.rand(1, 1, 8, 2, device="mps")

y_mps = F.scaled_dot_product_attention(q, k, v)
# With workaround: y_mps.shape == (1, 1, 8, 2) — correct!

Before: y_mps.shape == (1, 1, 8, 4) with uninitialized memory in last 2 channels
After: y_mps.shape == (1, 1, 8, 2) — correct output shape and values

Fixes #44554


Thank you for your attention. If there are any issues or suggestions, please leave a comment and I will address them promptly.

Warmly,
RoomWithOutRoof

… dim

Workaround for PyTorch < 2.12 bug (pytorch/pytorch#176767, pytorch/pytorch#176843)
where scaled_dot_product_attention on MPS produces incorrect output when
value head dim != query head dim.

DeepSeek models (MQA) are affected as they have different qk and v head dims.
The fix pads v to match q's head dim before SDPA, then truncates output
back to the original v size.

Fixes huggingface#44554
@Rocketknight1
Copy link
Copy Markdown
Member

@Jah-yee it's not actually helpful to just run your code agent on random issues! Hugging Face is a leading AI company, we're able to run agents ourselves if we need to 😅

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MPS] Upstream correctness issue in attention when value head dim differs from query

2 participants