From a7abb1930000c98d17e5c8c0cf96108e90ce9c54 Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Wed, 21 May 2025 22:34:05 +0000 Subject: [PATCH 1/9] switch to device agnostic autocast in nemotron to align xpu behavior w/ cuda Signed-off-by: Matrix Yao --- .../models/nemotron/modeling_nemotron.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 3e1ce01041b5..8bf813da7dcf 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -51,11 +51,11 @@ logger = logging.get_logger(__name__) -def _cast_if_autocast_enabled(*args): +def _cast_if_autocast_enabled(device_type, *args): if not torch.is_autocast_enabled(): return args else: - return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) + return torch.amp.autocast_mode._cast(args, torch.get_autocast_dtype(device_type)) class NemotronLayerNorm1P(nn.LayerNorm): @@ -70,9 +70,10 @@ def __init__( ): super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) - def forward(self, input: Tensor) -> Tensor: - args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps) - with torch.amp.autocast(input.device.type, enabled=False): + def forward(self, input: Tensor) -> Tensor: + device_type = input.device.type if isinstance(input.device.type, str) and input.device.type != "mps" else "cpu" + args = _cast_if_autocast_enabled(device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps) + with torch.autocast(device_type=input.device.type, enabled=False): return F.layer_norm(*args) @@ -344,9 +345,10 @@ def forward( # in fp32. 
(NemotronRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if isinstance(query_states.device.type, str) and query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = torch.get_autocast_dtype(device_type) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype From 1130d3b822166ef3ebd2a7a0081bee80dd4ae816 Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Thu, 22 May 2025 07:35:56 +0000 Subject: [PATCH 2/9] fix issue Signed-off-by: Matrix Yao --- src/transformers/models/nemotron/modeling_nemotron.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 8bf813da7dcf..b21a06dc5b7f 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -55,7 +55,7 @@ def _cast_if_autocast_enabled(device_type, *args): if not torch.is_autocast_enabled(): return args else: - return torch.amp.autocast_mode._cast(args, torch.get_autocast_dtype(device_type)) + return torch.amp.autocast_mode._cast(args, device_type, torch.get_autocast_dtype(device_type)) class NemotronLayerNorm1P(nn.LayerNorm): @@ -70,7 +70,7 @@ def __init__( ): super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) - def forward(self, input: Tensor) -> Tensor: + def forward(self, input: Tensor) -> Tensor: device_type = input.device.type if isinstance(input.device.type, str) and input.device.type != "mps" else "cpu" args = _cast_if_autocast_enabled(device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps) with torch.autocast(device_type=input.device.type, enabled=False): From ded4b4ed5e0c17bb93b53c02842de85cc9d36b87 Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Thu, 22 May 2025 07:36:23 +0000 Subject: [PATCH 3/9] fix style Signed-off-by: Matrix Yao --- src/transformers/models/nemotron/modeling_nemotron.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index b21a06dc5b7f..53b7d8125b26 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -72,7 +72,9 @@ def __init__( def forward(self, input: Tensor) -> Tensor: device_type = input.device.type if isinstance(input.device.type, str) and input.device.type != "mps" else "cpu" - args = _cast_if_autocast_enabled(device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps) + args = _cast_if_autocast_enabled( + device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps + ) with torch.autocast(device_type=input.device.type, enabled=False): return F.layer_norm(*args) @@ -345,7 +347,11 @@ def forward( # in fp32. 
(NemotronRMSNorm handles it correctly) input_dtype = query_states.dtype - device_type = query_states.device.type if isinstance(query_states.device.type, str) and query_states.device.type != "mps" else "cpu" + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_dtype(device_type) From 18a3bd88fa5565a79894f31db6760586885beda4 Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Fri, 23 May 2025 02:36:44 +0000 Subject: [PATCH 4/9] use torch.autocast as other modeling code for decision_transformer&gpt2&imagegpt Signed-off-by: Matrix Yao --- .../decision_transformer/modeling_decision_transformer.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/imagegpt/modeling_imagegpt.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index c577d6f17c65..b864bc2009fd 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -219,7 +219,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.amp.autocast(query.device.type, enabled=False): + with torch.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 4b3853b43691..8de8eb9001f8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -229,7 +229,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.amp.autocast(query.device.type, enabled=False): + with torch.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 68f6a04c5a2a..0614d0a95bb7 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -22,7 +22,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.cuda.amp import autocast from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -280,7 +279,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): + with torch.autocast(query.device.type, 
enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) From 2b8f5d4714c94180b2a6a04d3302e551fc16252f Mon Sep 17 00:00:00 2001 From: Matrix Yao Date: Fri, 23 May 2025 03:01:43 +0000 Subject: [PATCH 5/9] refine Signed-off-by: Matrix Yao --- .../models/nemotron/modeling_nemotron.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 5831f3e84d32..10d85465a302 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -55,7 +55,13 @@ def _cast_if_autocast_enabled(device_type, *args): if not torch.is_autocast_enabled(): return args else: - return torch.amp.autocast_mode._cast(args, device_type, torch.get_autocast_dtype(device_type)) + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + return torch.amp.autocast_mode._cast(args, device_type, target_dtype) class NemotronLayerNorm1P(nn.LayerNorm): @@ -354,7 +360,12 @@ def forward( ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_dtype(device_type) + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype From 74bf42784d401ad4a00a58a7370efd1f033bd05b Mon Sep 17 00:00:00 2001 From: Matrix YAO Date: Thu, 29 May 2025 05:53:19 +0000 Subject: [PATCH 6/9] update get_autocast_gpu_dtype to device agnostic one Signed-off-by: Matrix YAO --- src/transformers/models/dbrx/modeling_dbrx.py | 11 +++++++- .../models/diffllama/modeling_diffllama.py | 7 ++++- .../models/diffllama/modular_diffllama.py | 11 +++++++- .../models/distilbert/modeling_distilbert.py | 11 +++++++- src/transformers/models/esm/modeling_esm.py | 11 +++++++- .../models/esm/modeling_esmfold.py | 27 ++++++++++++++----- .../models/falcon/modeling_falcon.py | 11 +++++++- .../gpt_bigcode/modeling_gpt_bigcode.py | 11 +++++++- .../models/gpt_neo/modeling_gpt_neo.py | 11 +++++++- src/transformers/models/gptj/modeling_gptj.py | 11 +++++++- .../models/jamba/modeling_jamba.py | 11 +++++++- .../models/jetmoe/modeling_jetmoe.py | 11 +++++++- src/transformers/models/mimi/modeling_mimi.py | 11 +++++++- .../models/moshi/modeling_moshi.py | 11 +++++++- .../models/olmoe/modeling_olmoe.py | 11 +++++++- .../models/phimoe/modeling_phimoe.py | 11 +++++++- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 11 +++++++- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 11 +++++++- .../models/qwen2_moe/modeling_qwen2_moe.py | 11 +++++++- .../models/qwen2_vl/modeling_qwen2_vl.py | 7 ++++- tests/models/mamba2/test_modeling_mamba2.py | 2 +- 21 files changed, 204 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 0a530e87ae1b..5caa3f11c442 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py 
+++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -387,9 +387,18 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 68aa54180caa..92cb8a1e431c 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -306,9 +306,14 @@ def forward( # in fp32. (DiffLlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index b772a9f04d5f..46618ef8abf1 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -239,9 +239,18 @@ def forward( # in fp32. (DiffLlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 1a241884145d..5689e7b71b5f 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -289,9 +289,18 @@ def reshape(x: torch.Tensor) -> torch.Tensor: # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if query_states.dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index da62ecbea2a4..0b4341ad3589 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -458,9 +458,18 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. input_dtype = query_layer.dtype + device_type = ( + query_layer.device.type + if isinstance(query_layer.device.type, str) and query_layer.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 203aa9a69a39..408b2e26961f 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -132,9 +132,14 @@ class EsmForProteinFoldingOutput(ModelOutput): max_predicted_aligned_error: Optional[torch.FloatTensor] = None -def is_fp16_enabled(): +def is_fp16_enabled(device_type): # Autocast world - fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + autocast_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + fp16_enabled = autocast_dtype == torch.float16 fp16_enabled = fp16_enabled and torch.is_autocast_enabled() return fp16_enabled @@ -884,8 +889,13 @@ def forward( b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.linear_b_p(z) - if is_fp16_enabled(): - with torch.cuda.amp.autocast(enabled=False): + device_type = ( + a.device.type + if isinstance(a.device.type, str) and a.device.type != "mps" + else "cpu" + ) + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): x = self._combine_projections(a.float(), b.float()) else: x = self._combine_projections(a, b) @@ -1498,8 +1508,13 @@ def forward( z[0] = z[0].cpu() # [*, H, N_res, N_res] - if is_fp16_enabled(): - with torch.cuda.amp.autocast(enabled=False): + device_type = ( + q.device.type + if isinstance(q.device.type, str) and q.device.type != "mps" + else "cpu" + ) + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): a = torch.matmul( permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index df87d36242e0..656df5d32c16 100644 --- 
a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -488,9 +488,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_layer.dtype + device_type = ( + query_layer.device.type + if isinstance(query_layer.device.type, str) and query_layer.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7f40aabefb5b..269dabe8a587 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -343,9 +343,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query.dtype + device_type = ( + query.device.type + if isinstance(query.device.type, str) and query.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 95de9e82d5ec..f9507803a298 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -323,9 +323,18 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + device_type = ( + query.device.type + if isinstance(query.device.type, str) and query.device.type != "mps" + else "cpu" + ) if query.dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 093daaef193f..a5052f3868b2 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -355,9 +355,18 @@ def forward( # in fp32. 
(LlamaRMSNorm handles it correctly) input_dtype = query.dtype + device_type = ( + query.device.type + if isinstance(query.device.type, str) and query.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index d60190161ef6..df7fc1f02f25 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -412,9 +412,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 788b2066b5dc..531f954d19c7 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -710,9 +710,18 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 4d7b92979ab1..ef23acfd5dfd 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -661,9 +661,18 @@ def forward( # in fp32. 
(MimiRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 7e71eb2ce200..fd99ba3fe182 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -594,9 +594,18 @@ def forward( # in fp32. (MoshiRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 88f884dc2e43..6d9c9ac2b064 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -421,9 +421,18 @@ def forward( # in fp32. (OlmoeRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index e81d38e2d88d..982c45e25658 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -369,9 +369,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. 
input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 84e553465337..4a8159dcaf20 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1615,9 +1615,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f2c71f991fe9..5c923a86ef1c 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -843,9 +843,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a529dcdb559c..04a3e4aa75d1 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -418,9 +418,18 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. 
input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index b9f166074c88..a3920e879da5 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -678,9 +678,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = ( + query_states.device.type + if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + else "cpu" + ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 5777053923a3..dfa8bca69ef4 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -462,7 +462,7 @@ def test_mamba2_mixer_train_vs_eval_equivalence(self): config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) torch.manual_seed(42) - with torch.amp.autocast(device_type=torch_device, dtype=dtype): + with torch.autocast(device_type=torch_device, dtype=dtype): with torch.no_grad(): mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device) hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device) From 7241eb37bbc703360ce09acfbfb10b1d6e368284 Mon Sep 17 00:00:00 2001 From: Matrix YAO Date: Thu, 29 May 2025 05:55:51 +0000 Subject: [PATCH 7/9] fix style Signed-off-by: Matrix YAO --- .../models/diffllama/modeling_diffllama.py | 6 +++++- .../models/esm/modeling_esmfold.py | 20 ++++++------------- .../gpt_bigcode/modeling_gpt_bigcode.py | 6 +----- .../models/gpt_neo/modeling_gpt_neo.py | 8 ++------ src/transformers/models/gptj/modeling_gptj.py | 6 +----- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 6 +++++- 7 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 92cb8a1e431c..e509a4a605a6 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -313,7 +313,11 @@ def forward( ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if 
hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 408b2e26961f..e42060408068 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -135,10 +135,10 @@ class EsmForProteinFoldingOutput(ModelOutput): def is_fp16_enabled(device_type): # Autocast world autocast_dtype = ( - torch.get_autocast_dtype(device_type) - if hasattr(torch, "get_autocast_dtype") - else torch.get_autocast_gpu_dtype() - ) + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) fp16_enabled = autocast_dtype == torch.float16 fp16_enabled = fp16_enabled and torch.is_autocast_enabled() @@ -889,11 +889,7 @@ def forward( b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.linear_b_p(z) - device_type = ( - a.device.type - if isinstance(a.device.type, str) and a.device.type != "mps" - else "cpu" - ) + device_type = a.device.type if isinstance(a.device.type, str) and a.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): with torch.autocast(device_type=device_type, enabled=False): x = self._combine_projections(a.float(), b.float()) @@ -1508,11 +1504,7 @@ def forward( z[0] = z[0].cpu() # [*, H, N_res, N_res] - device_type = ( - q.device.type - if isinstance(q.device.type, str) and q.device.type != "mps" - else "cpu" - ) + device_type = q.device.type if isinstance(q.device.type, str) and q.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): with torch.autocast(device_type=device_type, enabled=False): a = torch.matmul( diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 269dabe8a587..ca5adafd7d5b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -343,11 +343,7 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query.dtype - device_type = ( - query.device.type - if isinstance(query.device.type, str) and query.device.type != "mps" - else "cpu" - ) + device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index f9507803a298..79b08facc634 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -323,14 +323,10 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - device_type = ( - query.device.type - if isinstance(query.device.type, str) and query.device.type != "mps" - else "cpu" - ) + device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" if query.dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = ( + target_dtype = ( torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a5052f3868b2..6665150f29d1 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -355,11 +355,7 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query.dtype - device_type = ( - query.device.type - if isinstance(query.device.type, str) and query.device.type != "mps" - else "cpu" - ) + device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = ( diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 5c923a86ef1c..a96e10b9a65b 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -850,7 +850,7 @@ def forward( ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = ( + target_dtype = ( torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a3920e879da5..856be28771bd 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -685,7 +685,11 @@ def forward( ) if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_dtype(device_type) if hasattr(torch, "get_autocast_dtype") else torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype From 2f2cc646ee9ca1726c4117fd03e4f173a948f000 Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Wed, 18 Jun 2025 00:03:52 +0000 Subject: [PATCH 8/9] fix comments Signed-off-by: YAO Matrix --- src/transformers/models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 4 ++-- src/transformers/models/diffllama/modeling_diffllama.py | 2 +- src/transformers/models/diffllama/modular_diffllama.py | 2 +- src/transformers/models/distilbert/modeling_distilbert.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- src/transformers/models/esm/modeling_esmfold.py | 4 ++-- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- 
src/transformers/models/moshi/modeling_moshi.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 4 ++-- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/phimoe/modeling_phimoe.py | 2 +- src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 +- 22 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index d6575a8751d0..64148e2457f0 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -100,7 +100,7 @@ def forward(self, x, position_ids): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 68ba9629f9a1..cf27c1c8a3ff 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -64,7 +64,7 @@ def forward(self, x, position_ids, seq_len=None): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -389,7 +389,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 8cc48db87843..494995596cbf 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -308,7 +308,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 8606523936d8..f1c242029284 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -241,7 +241,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if 
query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 45c03e1a7d15..f38714c13c34 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -291,7 +291,7 @@ def reshape(x: torch.Tensor) -> torch.Tensor: device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if query_states.dtype == torch.float32: diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index b216ebe8b7f2..37bd73d9dff3 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -461,7 +461,7 @@ def forward( input_dtype = query_layer.dtype device_type = ( query_layer.device.type - if isinstance(query_layer.device.type, str) and query_layer.device.type != "mps" + if query_layer.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 910333eb480f..cff95dc624d5 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -890,7 +890,7 @@ def forward( b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.linear_b_p(z) - device_type = a.device.type if isinstance(a.device.type, str) and a.device.type != "mps" else "cpu" + device_type = a.device.type if a.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): with torch.autocast(device_type=device_type, enabled=False): x = self._combine_projections(a.float(), b.float()) @@ -1505,7 +1505,7 @@ def forward( z[0] = z[0].cpu() # [*, H, N_res, N_res] - device_type = q.device.type if isinstance(q.device.type, str) and q.device.type != "mps" else "cpu" + device_type = q.device.type if q.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): with torch.autocast(device_type=device_type, enabled=False): a = torch.matmul( diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 0c8125c5cf35..06d6234cbcbe 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -490,7 +490,7 @@ def forward( input_dtype = query_layer.dtype device_type = ( query_layer.device.type - if isinstance(query_layer.device.type, str) and query_layer.device.type != "mps" + if query_layer.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 63c97dd5a7b8..b90fdfe8acc1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -343,7 +343,7 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. 
input_dtype = query.dtype - device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" + device_type = query.device.type if query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index ab108c7a80ae..8ac65c7d1ae8 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -323,7 +323,7 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) - device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" + device_type = query.device.type if query.device.type != "mps" else "cpu" if query.dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = ( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 9726f882e25a..4388fad01f60 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -355,7 +355,7 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query.dtype - device_type = query.device.type if isinstance(query.device.type, str) and query.device.type != "mps" else "cpu" + device_type = query.device.type if query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = ( diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index a127d89f8c3b..4566da2129f9 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -414,7 +414,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 6616a337858d..9ef99efe65e8 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -712,7 +712,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 641333539a8b..9e9d56b9d5e7 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -663,7 +663,7 @@ def forward( input_dtype = query_states.dtype device_type = ( query_states.device.type - if isinstance(query_states.device.type, str) and query_states.device.type != "mps" + if query_states.device.type != "mps" else "cpu" ) if input_dtype == torch.float32: diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 489a9137a454..44e8bc863e12 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -596,7 +596,7 @@ def forward( input_dtype 
= query_states.dtype
         device_type = (
             query_states.device.type
-            if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
+            if query_states.device.type != "mps"
             else "cpu"
         )
         if input_dtype == torch.float32:
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index e582e172dc51..8b4300fe3baa 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -77,7 +77,7 @@ def __init__(
         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
 
     def forward(self, input: Tensor) -> Tensor:
-        device_type = input.device.type if isinstance(input.device.type, str) and input.device.type != "mps" else "cpu"
+        device_type = input.device.type if input.device.type != "mps" else "cpu"
         args = _cast_if_autocast_enabled(
             device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
         )
@@ -355,7 +355,7 @@ def forward(
         input_dtype = query_states.dtype
         device_type = (
             query_states.device.type
-            if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
+            if query_states.device.type != "mps"
             else "cpu"
         )
         if input_dtype == torch.float32:
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index b4ebe34b2f87..1ae5cc9c6eec 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -423,7 +423,7 @@ def forward(
         input_dtype = query_states.dtype
         device_type = (
             query_states.device.type
-            if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
+            if query_states.device.type != "mps"
             else "cpu"
         )
         if input_dtype == torch.float32:
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index dc3ed66966db..347c544e5b7d 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -371,7 +371,7 @@ def forward(
         input_dtype = query_states.dtype
         device_type = (
             query_states.device.type
-            if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
+            if query_states.device.type != "mps"
             else "cpu"
         )
         if input_dtype == torch.float32:
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index d6f2a0e4b5fc..fd4dc0e2cc81 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2638,7 +2638,7 @@ def forward(self, x):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 58b3b87ad381..3d9cb1e91788 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2906,7 +2906,7 @@ def forward(self, x):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 8dd67e9f68de..d3db4c5d0025 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -420,7 +420,7 @@ def forward(
         input_dtype = query_states.dtype
         device_type = (
             query_states.device.type
-            if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
+            if query_states.device.type != "mps"
             else "cpu"
         )
         if input_dtype == torch.float32:
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 47e79f34870a..6723c520f282 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -78,7 +78,7 @@ def forward(self, x, position_ids, seq_len=None):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)

From aec8cc41ab076289645f49a55212d619d4fc4b2c Mon Sep 17 00:00:00 2001
From: YAO Matrix
Date: Wed, 18 Jun 2025 00:04:35 +0000
Subject: [PATCH 9/9] fix style

Signed-off-by: YAO Matrix
---
 src/transformers/models/dbrx/modeling_dbrx.py              | 6 +-----
 src/transformers/models/diffllama/modeling_diffllama.py    | 6 +-----
 src/transformers/models/diffllama/modular_diffllama.py     | 6 +-----
 src/transformers/models/distilbert/modeling_distilbert.py  | 6 +-----
 src/transformers/models/esm/modeling_esm.py                | 6 +-----
 src/transformers/models/falcon/modeling_falcon.py          | 6 +-----
 src/transformers/models/jamba/modeling_jamba.py            | 6 +-----
 src/transformers/models/jetmoe/modeling_jetmoe.py          | 6 +-----
 src/transformers/models/mimi/modeling_mimi.py              | 6 +-----
 src/transformers/models/moshi/modeling_moshi.py            | 6 +-----
 src/transformers/models/nemotron/modeling_nemotron.py      | 6 +-----
 src/transformers/models/olmoe/modeling_olmoe.py            | 6 +-----
 src/transformers/models/phimoe/modeling_phimoe.py          | 6 +-----
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py    | 6 +-----
 14 files changed, 14 insertions(+), 70 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index cf27c1c8a3ff..f1fee43a2391 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -387,11 +387,7 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 494995596cbf..fae9f2dbb95c 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -306,11 +306,7 @@ def forward(
         # in fp32. (DiffLlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index f1c242029284..0ff0465c7937 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -239,11 +239,7 @@ def forward(
         # in fp32. (DiffLlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index f38714c13c34..28cec74fb3d1 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -289,11 +289,7 @@ def reshape(x: torch.Tensor) -> torch.Tensor:
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
 
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if query_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index 37bd73d9dff3..10db78a67cb1 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -459,11 +459,7 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
         input_dtype = query_layer.dtype
-        device_type = (
-            query_layer.device.type
-            if query_layer.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 06d6234cbcbe..d6634662f30b 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -488,11 +488,7 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
-        device_type = (
-            query_layer.device.type
-            if query_layer.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 4566da2129f9..0b93d4484c9f 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -412,11 +412,7 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 9ef99efe65e8..d08c00c6c6dc 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -710,11 +710,7 @@ def forward(
         # in fp32. (LlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 9e9d56b9d5e7..45c307b61367 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -661,11 +661,7 @@ def forward(
         # in fp32. (MimiRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 44e8bc863e12..6eca81ef4977 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -594,11 +594,7 @@ def forward(
         # in fp32. (MoshiRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 8b4300fe3baa..d9d0248a33d4 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -353,11 +353,7 @@ def forward(
         # in fp32. (NemotronRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index 1ae5cc9c6eec..a9f2a08124e5 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -421,11 +421,7 @@ def forward(
         # in fp32. (OlmoeRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 347c544e5b7d..b1651f467b41 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -369,11 +369,7 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index d3db4c5d0025..cc617533582d 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -418,11 +418,7 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        device_type = (
-            query_states.device.type
-            if query_states.device.type != "mps"
-            else "cpu"
-        )
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = (