
fix(moe): Handle dtype mismatch in torch._grouped_mm with autocast #43839

Merged
ArthurZucker merged 2 commits into huggingface:main from Mr-Neutr0n:fix/phimoe-autocast-dtype on Feb 11, 2026

Conversation

@Mr-Neutr0n
Contributor

What does this PR do?

Fixes the RuntimeError: expected mat1 and mat2 to have the same dtype raised when using torch.autocast with MoE models like microsoft/Phi-tiny-MoE-instruct.

Problem

torch._grouped_mm is not autocast-enabled, meaning it doesn't automatically cast its input tensors to a common dtype when torch.autocast is active. This causes a dtype mismatch when:

  1. Model weights are in bfloat16 (e.g., loaded with dtype="bfloat16")
  2. Input hidden states may be float32 at certain computation stages
  3. The grouped GEMM operation then fails with a dtype mismatch error

Error trace:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

Solution

Add explicit dtype casting in _grouped_linear() to ensure input and weight tensors have the same dtype before calling torch._grouped_mm:

if input.dtype != weight.dtype:
    input = input.to(weight.dtype)

This is the same pattern used for other low-level ops that aren't autocast-enabled.

Note: _batched_linear doesn't need this fix because torch.bmm IS autocast-enabled.
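
As a small self-contained illustration of that contrast (not the transformers code; the shapes and the CPU autocast context are only for demonstration):

import torch

x = torch.randn(4, 8, 16)                          # float32 activations
w = torch.randn(4, 16, 32, dtype=torch.bfloat16)   # bfloat16 expert weights

# torch.bmm is autocast-enabled, so inside autocast both operands are cast for us
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = torch.bmm(x, w)

# An op that is not autocast-enabled sees the storage dtypes unchanged, so a
# float32/bfloat16 mix has to be reconciled by hand, as in the snippet above
if x.dtype != w.dtype:
    x = x.to(w.dtype)
out2 = torch.bmm(x, w)  # stand-in here for the non-autocast-aware grouped GEMM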

Fixes #43828

torch._grouped_mm is not autocast-enabled, so when using torch.autocast
with MoE models like Phi-tiny-MoE, the input tensor may have a different
dtype than the weights, causing RuntimeError.

Add explicit dtype casting in _grouped_linear to ensure input and weight
tensors have the same dtype before calling torch._grouped_mm.

Fixes huggingface#43828
@IlyasMoutawwakil
Member

hi ! wouldn't it make more sense to disable autocast around grouped_mm ?

@hxrikp1729

@IlyasMoutawwakil yeah, that's a fair point. Disabling autocast around grouped_mm would be cleaner since it addresses the root cause directly rather than adding a manual cast. The current approach works but it does mean we're doing an explicit dtype cast that could technically diverge from what autocast would have chosen.

Would something like wrapping the torch._grouped_mm calls with torch.amp.autocast(enabled=False) be what you had in mind? Happy to rework the PR in that direction if so.

@IlyasMoutawwakil
Member

yes let's do that, and open an issue in https://github.com/pytorch/pytorch/issues
@ArthurZucker wdyt ? i saw a couple PRs trying to fix this by casting inputs/outputs

Replace manual input.to(weight.dtype) cast with torch.amp.autocast(enabled=False)
wrapper around torch._grouped_mm calls. This is cleaner since it addresses the
root cause (grouped_mm not being autocast-aware) rather than guessing the target
dtype.

Ref: pytorch/pytorch#174763
@Mr-Neutr0n
Contributor Author

@IlyasMoutawwakil Done — reworked to wrap the torch._grouped_mm calls with torch.amp.autocast(device_type=..., enabled=False) instead of the manual dtype cast. Much cleaner.
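
For reference, the change is roughly of the following shape (a sketch only; the helper name and the offs keyword are placeholders, not the exact merged code):

import torch

def _grouped_linear_sketch(hidden_states, expert_weights, offsets):
    # Run the grouped GEMM outside any surrounding autocast context so the
    # autocast dispatcher no longer decides dtypes for this call
    with torch.amp.autocast(device_type=hidden_states.device.type, enabled=False):
        # offs is an assumed keyword; check torch._grouped_mm's signature for
        # the installed PyTorch version
        return torch._grouped_mm(hidden_states, expert_weights, offs=offsets)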

Also opened the upstream PyTorch issue: pytorch/pytorch#174763

@IlyasMoutawwakil
Member

LGTM ! did you verify that it solves the issue on the relevant model ?

@ArthurZucker
Collaborator

#43833 came first, but this fix is more elegant so let's go with this one!

ArthurZucker merged commit 007bb8c into huggingface:main on Feb 11, 2026
25 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Mr-Neutr0n
Contributor Author

Thank you for the kind words @ArthurZucker



Development

Successfully merging this pull request may close these issues.

With torch.autocast, Phi-tiny-MoE-instruct raises a dtype mismatch error

5 participants