diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index d6575a8751d0..64148e2457f0 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -100,7 +100,7 @@ def forward(self, x, position_ids): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 097370a46cd5..f1fee43a2391 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -64,7 +64,7 @@ def forward(self, x, position_ids, seq_len=None): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -387,9 +387,14 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index d34e989986d1..a436ce6d2c03 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -219,7 +219,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.amp.autocast(query.device.type, enabled=False): + with torch.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 43e8e4584bdc..fae9f2dbb95c 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -306,9 +306,14 @@ def forward( # in fp32. (DiffLlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 934baf9c0b1e..0ff0465c7937 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -239,9 +239,14 @@ def forward( # in fp32. (DiffLlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 1a84544bee3f..28cec74fb3d1 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -289,9 +289,14 @@ def reshape(x: torch.Tensor) -> torch.Tensor: # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if query_states.dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index dbf260fb2160..10db78a67cb1 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -459,9 +459,14 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. input_dtype = query_layer.dtype + device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 3338b3d50073..644fec2e1a32 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput): max_predicted_aligned_error: Optional[torch.FloatTensor] = None -def is_fp16_enabled(): +def is_fp16_enabled(device_type): # Autocast world - fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + autocast_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + fp16_enabled = autocast_dtype == torch.float16 fp16_enabled = fp16_enabled and torch.is_autocast_enabled() return fp16_enabled @@ -885,8 +890,9 @@ def forward( b = b * self.sigmoid(self.linear_b_g(z)) b = b * self.linear_b_p(z) - if is_fp16_enabled(): - with torch.cuda.amp.autocast(enabled=False): + device_type = a.device.type if a.device.type != "mps" else "cpu" + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): x = self._combine_projections(a.float(), b.float()) else: x = self._combine_projections(a, b) @@ -1499,8 +1505,9 @@ def forward( z[0] = z[0].cpu() # [*, H, N_res, N_res] - if is_fp16_enabled(): - with torch.cuda.amp.autocast(enabled=False): + device_type = q.device.type if q.device.type != "mps" else "cpu" + if is_fp16_enabled(device_type): + with torch.autocast(device_type=device_type, enabled=False): a = torch.matmul( permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 455afbb21577..d6634662f30b 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -488,9 +488,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_layer.dtype + device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 584d21c41087..fa98bc3614e5 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -229,7 +229,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.amp.autocast(query.device.type, enabled=False): + with torch.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 297c30cb06a1..b90fdfe8acc1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -343,9 +343,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query.dtype + device_type = query.device.type if query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 417eb11ab0ff..8ac65c7d1ae8 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -323,9 +323,14 @@ def forward( # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + device_type = query.device.type if query.device.type != "mps" else "cpu" if query.dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 1e9c0ef5332f..4388fad01f60 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -355,9 +355,14 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query.dtype + device_type = query.device.type if query.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index db5ae763aadd..65d5cbc3df21 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -22,7 +22,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.cuda.amp import autocast from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -280,7 +279,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): + with torch.autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 52570d8f7f8e..0b93d4484c9f 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -412,9 +412,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index ab843099c54b..d08c00c6c6dc 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -710,9 +710,14 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 9023c93433ab..45c307b61367 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -661,9 +661,14 @@ def forward( # in fp32. (MimiRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index c0ef9c001471..6eca81ef4977 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -594,9 +594,14 @@ def forward( # in fp32. (MoshiRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 6b5cf370c9da..d9d0248a33d4 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -51,11 +51,17 @@ logger = logging.get_logger(__name__) -def _cast_if_autocast_enabled(*args): +def _cast_if_autocast_enabled(device_type, *args): if not torch.is_autocast_enabled(): return args else: - return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + return torch.amp.autocast_mode._cast(args, device_type, target_dtype) class NemotronLayerNorm1P(nn.LayerNorm): @@ -71,8 +77,11 @@ def __init__( super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) def forward(self, input: Tensor) -> Tensor: - args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps) - with torch.amp.autocast(input.device.type, enabled=False): + device_type = input.device.type if input.device.type != "mps" else "cpu" + args = _cast_if_autocast_enabled( + device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps + ) + with torch.autocast(device_type=input.device.type, enabled=False): return F.layer_norm(*args) @@ -344,9 +353,15 @@ def forward( # in fp32. (NemotronRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4 + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 6dc3c12c1ffb..a9f2a08124e5 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -421,9 +421,14 @@ def forward( # in fp32. (OlmoeRMSNorm handles it correctly) input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 2a9240783399..b1651f467b41 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -369,9 +369,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c4e151e9ce16..1ccc6ea0bfbc 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -2638,7 +2638,7 @@ def forward(self, x): batch_size, seq_len = x.shape[0], x.shape[1] t = torch.arange(seq_len, device=x.device) device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() freqs = torch.stack((freqs, freqs), dim=-1) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 10edb4e6a439..3779a2ad0615 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2906,7 +2906,7 @@ def forward(self, x): batch_size, seq_len = x.shape[0], x.shape[1] t = torch.arange(seq_len, device=x.device) device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() freqs = torch.stack((freqs, freqs), dim=-1) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 9882ca447d7e..cc617533582d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -418,9 +418,14 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" if input_dtype == torch.float32: if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 47e79f34870a..6723c520f282 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -78,7 +78,7 @@ def forward(self, x, position_ids, seq_len=None): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 5777053923a3..dfa8bca69ef4 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -462,7 +462,7 @@ def test_mamba2_mixer_train_vs_eval_equivalence(self): config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) torch.manual_seed(42) - with torch.amp.autocast(device_type=torch_device, dtype=dtype): + with torch.autocast(device_type=torch_device, dtype=dtype): with torch.no_grad(): mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device) hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)