fix(moe): Handle dtype mismatch in torch._grouped_mm with autocast #43839

ArthurZucker merged 2 commits into huggingface:main
Conversation
`torch._grouped_mm` is not autocast-enabled, so when using `torch.autocast` with MoE models like Phi-tiny-MoE, the input tensor may have a different dtype than the weights, causing a `RuntimeError`. Add explicit dtype casting in `_grouped_linear` to ensure input and weight tensors have the same dtype before calling `torch._grouped_mm`. Fixes huggingface#43828
hi! wouldn't it make more sense to disable autocast around `grouped_mm`?
@IlyasMoutawwakil yeah, that's a fair point. Disabling autocast around `torch._grouped_mm` would address the root cause rather than guessing the target dtype. Would something like wrapping the calls in `torch.amp.autocast(enabled=False)` work?
yes let's do that, and open an issue in https://github.com/pytorch/pytorch/issues
Replace manual `input.to(weight.dtype)` cast with `torch.amp.autocast(enabled=False)` wrapper around `torch._grouped_mm` calls. This is cleaner since it addresses the root cause (`grouped_mm` not being autocast-aware) rather than guessing the target dtype. Ref: pytorch/pytorch#174763
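A minimal sketch of that shape of change (the helper and `offs` argument names here are illustrative, not the exact transformers code):

```python
import torch

def _grouped_linear(inputs, weight, offs):
    # torch._grouped_mm is not autocast-aware, so instead of guessing a
    # target dtype, run it with autocast locally disabled; the operands
    # then keep their original, matching dtypes.
    with torch.amp.autocast(device_type=inputs.device.type, enabled=False):
        return torch._grouped_mm(inputs, weight, offs=offs)
```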
@IlyasMoutawwakil Done: reworked to wrap the `torch._grouped_mm` calls in `torch.amp.autocast(enabled=False)` instead of casting manually. Also opened the upstream PyTorch issue: pytorch/pytorch#174763
IlyasMoutawwakil left a comment
LGTM! Did you verify that it solves the issue on the relevant model?
ArthurZucker left a comment
#43833 came first, but this fix is more elegant, so let's go with this one!
Thank you for the kind words @ArthurZucker
What does this PR do?
Fixes a `RuntimeError: expected mat1 and mat2 to have the same dtype` error when using `torch.autocast` with MoE models like `microsoft/Phi-tiny-MoE-instruct`.

Problem

`torch._grouped_mm` is not autocast-enabled, meaning it doesn't automatically cast input tensors to match the weights when `torch.autocast` is active. This causes a dtype mismatch when the model weights are loaded in bf16 (e.g. `dtype="bfloat16"`) but the input tensor reaches the op in a different dtype under autocast.

Error trace: `RuntimeError: expected mat1 and mat2 to have the same dtype`
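For illustration, a minimal repro sketch; the exact load/autocast dtype pairing from the linked issue is an assumption here:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-tiny-MoE-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="bfloat16").to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt").to("cuda")

# Under autocast, activations can reach the MoE experts in a dtype other
# than the bf16 expert weights; torch._grouped_mm does not reconcile
# them, raising "expected mat1 and mat2 to have the same dtype".
with torch.autocast("cuda"):  # default CUDA autocast dtype is float16
    model(**inputs)
```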
Solution
Add explicit dtype casting in `_grouped_linear()` to ensure input and weight tensors have the same dtype before calling `torch._grouped_mm`; a sketch follows below. This is the same pattern used in other low-level ops that aren't autocast-enabled.
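A sketch of that pattern (paraphrasing the diff, with an illustrative helper signature):

```python
import torch

def _grouped_linear(inputs, weight, offs):
    # torch._grouped_mm won't cast operands for us under autocast, so
    # align the input dtype to the weight dtype before the matmul.
    if inputs.dtype != weight.dtype:
        inputs = inputs.to(weight.dtype)
    return torch._grouped_mm(inputs, weight, offs=offs)
```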
Note: `_batched_linear` doesn't need this fix because `torch.bmm` IS autocast-enabled.

Fixes #43828
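As a quick check of that note, a sketch (assumes a CUDA device):

```python
import torch

a = torch.randn(2, 3, 4, device="cuda")                        # float32
b = torch.randn(2, 4, 5, device="cuda", dtype=torch.bfloat16)  # bfloat16

with torch.autocast("cuda", dtype=torch.bfloat16):
    out = torch.bmm(a, b)  # autocast casts both operands to bfloat16
print(out.dtype)  # torch.bfloat16
```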