diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py
index 23db95815c54..f94cba1ede4f 100644
--- a/src/transformers/integrations/moe.py
+++ b/src/transformers/integrations/moe.py
@@ -184,12 +184,15 @@ def _grouped_linear(
 
     Returns:
         `torch.Tensor`: Output tensor of shape (S, output_dim).
     """
-    if is_transposed:
-        # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
-        out = torch._grouped_mm(input, weight, offs=offs)
-    else:
-        # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
-        out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)
+    # torch._grouped_mm is not autocast-enabled, so we disable autocast to avoid dtype mismatch.
+    # See: https://github.com/pytorch/pytorch/issues/174763
+    with torch.amp.autocast(device_type=input.device.type, enabled=False):
+        if is_transposed:
+            # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
+            out = torch._grouped_mm(input, weight, offs=offs)
+        else:
+            # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
+            out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)
     if bias is not None:
         # We should be able to pass bias to the grouped_mm call, but it's not yet supported.
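
Context for the guard: under autocast, ops that have an autocast rule (e.g. `matmul`) cast their operands to the autocast dtype, while ops without one, such as `torch._grouped_mm`, receive tensors in whatever dtypes they already carry and can fail on mixed fp32/bf16 operands. Below is a minimal, CPU-runnable sketch of the same guard pattern; `strict_dtype_mm` is a hypothetical stand-in for `torch._grouped_mm` (which requires specific hardware), not part of transformers or PyTorch.

```python
import torch


def strict_dtype_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for an op with no autocast rule: it refuses
    # mixed-dtype operands instead of casting them the way autocast-enabled
    # ops do.
    if a.dtype != b.dtype:
        raise RuntimeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    return a @ b


x = torch.randn(4, 8)  # fp32 activations
w = torch.randn(8, 8)  # fp32 weights

with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = x @ w  # matmul is autocast-enabled -> h is bf16
    # strict_dtype_mm(h, w)  # would raise: bf16 activations vs fp32 weights
    with torch.amp.autocast(device_type="cpu", enabled=False):
        # Mirror the patch's guard: run the strict op outside autocast.
        # The explicit .to() cast unifies dtypes for this sketch only.
        out = strict_dtype_mm(h, w.to(h.dtype))

print(out.dtype)  # torch.bfloat16
```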