From 1d3bb2d1eb626b509a75f3d0a18d66c5f4c98679 Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Sun, 8 Feb 2026 17:50:59 +0530
Subject: [PATCH 1/2] fix(moe): Handle dtype mismatch in grouped_mm with
 autocast

torch._grouped_mm is not autocast-enabled, so when using torch.autocast
with MoE models like Phi-tiny-MoE, the input tensor may have a different
dtype than the weights, causing a RuntimeError.

Add explicit dtype casting in _grouped_linear to ensure the input and
weight tensors have the same dtype before calling torch._grouped_mm.

Fixes #43828
---
 src/transformers/integrations/moe.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py
index 23db95815c54..68df7b8a15cc 100644
--- a/src/transformers/integrations/moe.py
+++ b/src/transformers/integrations/moe.py
@@ -184,6 +184,11 @@ def _grouped_linear(
     Returns:
         `torch.Tensor`: Output tensor of shape (S, output_dim).
     """
+    # torch._grouped_mm is not autocast-enabled, so we need to ensure dtype compatibility manually.
+    # Cast input to match weight dtype to avoid dtype mismatch errors.
+    if input.dtype != weight.dtype:
+        input = input.to(weight.dtype)
+
     if is_transposed:
         # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
         out = torch._grouped_mm(input, weight, offs=offs)

From 60b7a73b85ff42e34db1221322109c6d1cc51aac Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Wed, 11 Feb 2026 15:12:46 +0530
Subject: [PATCH 2/2] refactor: use autocast(enabled=False) instead of manual
 dtype cast

Replace the manual input.to(weight.dtype) cast with a
torch.amp.autocast(enabled=False) wrapper around the torch._grouped_mm
calls. This is cleaner since it addresses the root cause (grouped_mm not
being autocast-aware) rather than guessing the target dtype.

Ref: https://github.com/pytorch/pytorch/issues/174763
---
 src/transformers/integrations/moe.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py
index 68df7b8a15cc..f94cba1ede4f 100644
--- a/src/transformers/integrations/moe.py
+++ b/src/transformers/integrations/moe.py
@@ -184,17 +184,15 @@ def _grouped_linear(
     Returns:
         `torch.Tensor`: Output tensor of shape (S, output_dim).
     """
-    # torch._grouped_mm is not autocast-enabled, so we need to ensure dtype compatibility manually.
-    # Cast input to match weight dtype to avoid dtype mismatch errors.
-    if input.dtype != weight.dtype:
-        input = input.to(weight.dtype)
-
-    if is_transposed:
-        # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
-        out = torch._grouped_mm(input, weight, offs=offs)
-    else:
-        # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
-        out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)
+    # torch._grouped_mm is not autocast-enabled, so we disable autocast to avoid dtype mismatch.
+    # See: https://github.com/pytorch/pytorch/issues/174763
+    with torch.amp.autocast(device_type=input.device.type, enabled=False):
+        if is_transposed:
+            # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
+            out = torch._grouped_mm(input, weight, offs=offs)
+        else:
+            # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
+            out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)
 
     if bias is not None:
         # We should be able to pass bias to the grouped_mm call, but it's not yet supported.
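
Reviewer note (not part of the patch): a minimal repro sketch of the failure
these commits address. The checkpoint id is an assumption based on the
Phi-tiny-MoE mention in the commit message, and running it requires a PyTorch
build that exposes torch._grouped_mm on a GPU that supports it:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Checkpoint name assumed for illustration; any MoE model routed through
    # the grouped-mm path in src/transformers/integrations/moe.py should
    # exercise the same code.
    model_id = "microsoft/Phi-tiny-MoE-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")

    inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")

    # Before this patch, the forward pass raised a RuntimeError inside
    # torch._grouped_mm: upstream autocast-enabled ops produced bf16
    # activations while the expert weights stayed fp32, and the op performs
    # no implicit casting of its own.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(**inputs)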