Merged

Changes from all commits (25 commits)
a7abb19
switch to device agnostic autocast in nemotron to align xpu behavior w/
yao-matrix May 21, 2025
1130d3b
fix issue
yao-matrix May 22, 2025
ded4b4e
fix style
yao-matrix May 22, 2025
2be4cef
Merge branch 'main' into autocast-xpu
yao-matrix May 22, 2025
e32878c
Merge branch 'main' into autocast-xpu
yao-matrix May 22, 2025
18a3bd8
use torch.autocast as other modeling code for decision_transformer&gpt2&i…
yao-matrix May 23, 2025
2b8f5d4
refine
yao-matrix May 23, 2025
86f9f00
Merge branch 'main' into autocast-xpu
yao-matrix May 25, 2025
8a70406
Merge branch 'main' into autocast-xpu
yao-matrix May 26, 2025
cc09702
Merge branch 'main' into autocast-xpu
yao-matrix May 29, 2025
74bf427
update get_autocast_gpu_dtype to device agnostic one
yao-matrix May 29, 2025
7241eb3
fix style
yao-matrix May 29, 2025
3b40d0d
Merge branch 'main' into autocast-xpu
yao-matrix May 30, 2025
e1c3eb7
Merge branch 'main' into autocast-xpu
yao-matrix Jun 3, 2025
38932c7
Merge branch 'main' into autocast-xpu
yao-matrix Jun 6, 2025
0485942
Merge branch 'main' into autocast-xpu
yao-matrix Jun 9, 2025
7ba99c7
Merge branch 'main' into autocast-xpu
yao-matrix Jun 9, 2025
bd84a90
Merge branch 'main' into autocast-xpu
yao-matrix Jun 11, 2025
58400de
Merge branch 'main' into autocast-xpu
yao-matrix Jun 12, 2025
82b840c
Merge branch 'main' into autocast-xpu
yao-matrix Jun 16, 2025
9d449b8
Merge branch 'main' into autocast-xpu
yao-matrix Jun 17, 2025
2f2cc64
fix comments
yao-matrix Jun 18, 2025
aec8cc4
fix style
yao-matrix Jun 18, 2025
3383a22
Merge branch 'main' into autocast-xpu
yao-matrix Jun 18, 2025
71d907e
Merge branch 'main' into autocast-xpu
ydshieh Jun 19, 2025
2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -100,7 +100,7 @@ def forward(self, x, position_ids):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
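Note: the `"mps"` to `"cpu"` remap exists because the targeted PyTorch versions expose no autocast backend for Apple's MPS device; since the context is entered with `enabled=False`, the fallback key only has to be valid and never changes numerics. A minimal sketch of the pattern, assuming an arbitrary floating-point input (the helper name is ours, not from the PR):

```python
import torch


def fp32_rotary_freqs(inv_freq: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # Pick a valid autocast device key; "mps" has none, so fall back to "cpu".
    device_type = inv_freq.device.type if inv_freq.device.type != "mps" else "cpu"
    # Disable autocast locally so the matmul really runs in float32.
    with torch.autocast(device_type=device_type, enabled=False):
        return (inv_freq.float()[None, :, None] @ position_ids.float()[:, None, :]).transpose(1, 2)
```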
9 changes: 7 additions & 2 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -64,7 +64,7 @@ def forward(self, x, position_ids, seq_len=None):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -387,9 +387,14 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
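Note: the `hasattr` guard introduced here (and repeated in most hunks below) is a version shim: `torch.get_autocast_dtype(device_type)` only exists from PyTorch 2.4 on and understands "cuda", "xpu", "cpu", and friends, while older releases only offer the CUDA-specific `torch.get_autocast_gpu_dtype()`. A standalone sketch of the lookup, under a hypothetical helper name:

```python
import torch


def autocast_target_dtype(device_type: str) -> torch.dtype:
    # PyTorch >= 2.4: one device-agnostic query.
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype(device_type)
    # Older PyTorch: only the CUDA-flavored accessor exists.
    return torch.get_autocast_gpu_dtype()
```

For example, inside `with torch.autocast("cpu", dtype=torch.bfloat16):` this returns `torch.bfloat16` on PyTorch 2.4+.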
2 changes: 1 addition & 1 deletion src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -219,7 +219,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):

Contributor Author: align to other modeling code in models directory

             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
         attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
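Note: this hunk is purely cosmetic: `torch.autocast` and `torch.amp.autocast` name the same context manager, so the change only aligns the spelling with the rest of the `models` directory. A small sketch of the upcast idiom it touches, with made-up tensors:

```python
import torch

query = torch.randn(1, 4, 16, 8)  # (batch, heads, seq, head_dim), any device
key = torch.randn(1, 4, 16, 8)

# Turn autocast off for this region and force the attention matmul into fp32.
with torch.autocast(query.device.type, enabled=False):
    attn_weights = query.float() @ key.float().transpose(-1, -2)
```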
7 changes: 6 additions & 1 deletion src/transformers/models/diffllama/modeling_diffllama.py
@@ -306,9 +306,14 @@ def forward(
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/diffllama/modular_diffllama.py
@@ -239,9 +239,14 @@ def forward(
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -289,9 +289,14 @@ def reshape(x: torch.Tensor) -> torch.Tensor:
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if query_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/esm/modeling_esm.py
@@ -459,9 +459,14 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
19 changes: 13 additions & 6 deletions src/transformers/models/esm/modeling_esmfold.py
@@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput):
     max_predicted_aligned_error: Optional[torch.FloatTensor] = None


-def is_fp16_enabled():
+def is_fp16_enabled(device_type):
     # Autocast world
-    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+    autocast_dtype = (
+        torch.get_autocast_dtype(device_type)
+        if hasattr(torch, "get_autocast_dtype")
+        else torch.get_autocast_gpu_dtype()
+    )
+    fp16_enabled = autocast_dtype == torch.float16
     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

     return fp16_enabled
@@ -885,8 +890,9 @@ def forward(
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)

-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = a.device.type if a.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
@@ -1499,8 +1505,9 @@ def forward(
             z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = q.device.type if q.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
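Note: `is_fp16_enabled` now asks autocast about a specific device type instead of assuming CUDA. A self-contained sketch of the guarded-upcast flow the two call sites above use (`combine_in_fp32` is a hypothetical stand-in for `_combine_projections`):

```python
import torch


def is_fp16_enabled(device_type):
    autocast_dtype = (
        torch.get_autocast_dtype(device_type)  # PyTorch >= 2.4
        if hasattr(torch, "get_autocast_dtype")
        else torch.get_autocast_gpu_dtype()  # older, CUDA-only accessor
    )
    return torch.is_autocast_enabled() and autocast_dtype == torch.float16


def combine_in_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Under fp16 autocast, enter an autocast-free region and upcast for accuracy.
    device_type = a.device.type if a.device.type != "mps" else "cpu"
    if is_fp16_enabled(device_type):
        with torch.autocast(device_type=device_type, enabled=False):
            return a.float() @ b.float()
    return a @ b
```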
7 changes: 6 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -488,9 +488,14 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
@@ -229,7 +229,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
         attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
7 changes: 6 additions & 1 deletion src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -343,9 +343,14 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -323,9 +323,14 @@ def forward(
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if query.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -355,9 +355,14 @@ def forward(
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
3 changes: 1 addition & 2 deletions src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -22,7 +22,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -280,7 +279,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
         attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
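Note: the `from torch.cuda.amp import autocast` import is dropped because that entry point is deprecated and CUDA-only; deriving the backend from the tensor keeps the same upcast behavior on CUDA, XPU, and CPU. A minimal before/after sketch:

```python
import torch

x = torch.randn(4, 4)  # behaves the same for cuda, xpu, or cpu tensors

# Before: `from torch.cuda.amp import autocast` + `with autocast(enabled=False):`
# After, device agnostic:
with torch.autocast(x.device.type, enabled=False):
    y = x.float() @ x.float().T
```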
7 changes: 6 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -412,9 +412,14 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -710,9 +710,14 @@ def forward(
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/mimi/modeling_mimi.py
@@ -661,9 +661,14 @@ def forward(
         # in fp32. (MimiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/moshi/modeling_moshi.py
@@ -594,9 +594,14 @@ def forward(
         # in fp32. (MoshiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
25 changes: 20 additions & 5 deletions src/transformers/models/nemotron/modeling_nemotron.py
@@ -51,11 +51,17 @@
 logger = logging.get_logger(__name__)


-def _cast_if_autocast_enabled(*args):
+def _cast_if_autocast_enabled(device_type, *args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
+        # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+        target_dtype = (
+            torch.get_autocast_dtype(device_type)
+            if hasattr(torch, "get_autocast_dtype")
+            else torch.get_autocast_gpu_dtype()
+        )
+        return torch.amp.autocast_mode._cast(args, device_type, target_dtype)


 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -71,8 +77,11 @@ def __init__(
         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)

     def forward(self, input: Tensor) -> Tensor:
-        args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.amp.autocast(input.device.type, enabled=False):
+        device_type = input.device.type if input.device.type != "mps" else "cpu"
+        args = _cast_if_autocast_enabled(
+            device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
+        )
+        with torch.autocast(device_type=input.device.type, enabled=False):
             return F.layer_norm(*args)


@@ -344,9 +353,15 @@ def forward(
         # in fp32. (NemotronRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
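Note: the Nemotron helper pre-casts the layer-norm operands to the autocast dtype, then runs `F.layer_norm` with autocast disabled, so the kernel sees consistent dtypes instead of being re-cast mid-op. A condensed, self-contained sketch of that flow (it leans on PyTorch's private `torch.amp.autocast_mode._cast`, exactly as the diff does; the function names here are ours):

```python
import torch
import torch.nn.functional as F


def cast_if_autocast_enabled(device_type, *args):
    if not torch.is_autocast_enabled():
        return args
    target_dtype = (
        torch.get_autocast_dtype(device_type)  # PyTorch >= 2.4
        if hasattr(torch, "get_autocast_dtype")
        else torch.get_autocast_gpu_dtype()
    )
    return torch.amp.autocast_mode._cast(args, device_type, target_dtype)


def layer_norm_1p(x, weight, bias, eps=1e-5):
    # "+1" weight parameterization; cast outside the op, compute autocast-free.
    device_type = x.device.type if x.device.type != "mps" else "cpu"
    args = cast_if_autocast_enabled(device_type, x, x.shape[-1:], weight + 1, bias, eps)
    with torch.autocast(device_type=device_type, enabled=False):
        return F.layer_norm(*args)
```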
7 changes: 6 additions & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
@@ -421,9 +421,14 @@ def forward(
         # in fp32. (OlmoeRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
7 changes: 6 additions & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
@@ -369,9 +369,14 @@ def forward(
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype