
fix: ensure dtype consistency in grouped_mm under autocast#43833

Open
nulone wants to merge 2 commits into huggingface:main from nulone:fix/43828-phimoe-dtype-autocast

Conversation

@nulone

@nulone nulone commented Feb 8, 2026

Fixes #43828

What does this PR do?

torch._grouped_mm is not registered for autocast. Under torch.autocast, LayerNorm outputs float32 while the model weights stay in bfloat16, so the grouped matmul fails with RuntimeError: "expected mat1 and mat2 to have same dtype".
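A minimal sketch of the divergence (illustrative shapes, a CUDA device assumed; not taken from the model code):

```python
import torch

# Illustrative shapes only; a stack of per-expert weight matrices kept in bfloat16.
ln = torch.nn.LayerNorm(16).cuda()
expert_weight = torch.randn(4, 16, 32, dtype=torch.bfloat16, device="cuda")

x = torch.randn(8, 16, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    normed = ln(x)
    print(normed.dtype, expert_weight.dtype)  # torch.float32 torch.bfloat16
    # torch._grouped_mm has no autocast registration, so feeding `normed`
    # against `expert_weight` raises:
    # RuntimeError: expected mat1 and mat2 to have same dtype
```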

Fix

Cast input to weight.dtype before calling _grouped_mm in src/transformers/integrations/moe.py.
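Roughly, the change amounts to the following sketch (hypothetical helper name, not the exact diff; the torch._grouped_mm keyword arguments follow recent PyTorch and may differ between versions):

```python
import torch

def autocast_safe_grouped_mm(hidden_states, weight, offs):
    # Under autocast the activations can arrive as float32 while the stacked
    # expert weights stay bfloat16; align the input with the weight dtype
    # before calling the grouped matmul, which is not autocast-aware.
    return torch._grouped_mm(hidden_states.to(weight.dtype), weight, offs=offs)
```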

Impact

Affects all MoE models using grouped_mm under autocast (Mixtral, Qwen3 MoE, DeepSeek, PhiMoE, etc.)

Before submitting

Note: No local GPU access — relying on CI for verification.

torch._grouped_mm is not registered for autocast, causing dtype mismatch
when LayerNorm outputs float32 but weights are bfloat16.

Fixes huggingface#43828
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey! does not look too bad, just unsure, should we cast back post op?

@nulone
Author

nulone commented Feb 11, 2026

No need — the result is already cast back to hidden_states.dtype at line 273:

https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/moe.py#L273

So the dtype flow is: input (float32) → cast to weight.dtype (bf16) → _grouped_mm → cast back to hidden_states.dtype (float32) at the end.
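In sketch form (hypothetical names, not the actual moe.py code):

```python
import torch

def grouped_expert_mm(hidden_states, weight, offs):
    out_dtype = hidden_states.dtype              # e.g. float32 from LayerNorm under autocast
    x = hidden_states.to(weight.dtype)           # cast to the weight dtype (bfloat16)
    x = torch._grouped_mm(x, weight, offs=offs)  # grouped matmul runs entirely in bfloat16
    return x.to(out_dtype)                       # cast back to the caller's dtype (the cast at moe.py line 273)
```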



Development

Successfully merging this pull request may close these issues.

With torch.autocast, Phi-tiny-MoE-instruct raises a dtype mismatch error
