2 changes: 1 addition & 1 deletion .github/workflows/build-ci-docker-images.yml
@@ -74,4 +74,4 @@ jobs:
slack_channel: "#transformers-ci-circleci-images"
title: 🤗 New docker images for CircleCI are pushed.
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -930,6 +930,7 @@
"is_tokenizers_available",
"is_torch_available",
"is_torch_mlu_available",
"is_torch_musa_available",
"is_torch_neuroncore_available",
"is_torch_npu_available",
"is_torch_tpu_available",
@@ -5706,6 +5707,7 @@
is_tokenizers_available,
is_torch_available,
is_torch_mlu_available,
is_torch_musa_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
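With the new re-export in place, callers can query MUSA support straight from the top-level package. A minimal sketch (the device-selection logic is illustrative, not part of this PR):

from transformers import is_torch_musa_available

device = "musa:0" if is_torch_musa_available() else "cpu"
print(f"selected device: {device}")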
6 changes: 6 additions & 0 deletions src/transformers/pipelines/base.py
@@ -45,6 +45,7 @@
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
@@ -873,6 +874,8 @@ def __init__(
self.device = torch.device("cpu")
elif is_torch_mlu_available():
self.device = torch.device(f"mlu:{device}")
elif is_torch_musa_available():
self.device = torch.device(f"musa:{device}")
elif is_torch_cuda_available():
self.device = torch.device(f"cuda:{device}")
elif is_torch_npu_available():
@@ -1042,6 +1045,9 @@ def device_placement(self):
elif self.device.type == "mlu":
with torch.mlu.device(self.device):
yield
elif self.device.type == "musa":
with torch.musa.device(self.device):
yield
else:
yield

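With these branches, an integer device argument resolves to a MUSA device whenever torch_musa is importable (the MUSA check runs before the CUDA one), and device_placement() gains a matching context-manager branch. A minimal usage sketch, assuming a machine with torch_musa installed and the task's default model:

from transformers import pipeline

# On a MUSA machine, device=0 now resolves to "musa:0"; elsewhere the
# existing cpu/mlu/cuda/npu/xpu resolution is unchanged.
classifier = pipeline("text-classification", device=0)
print(classifier.device)  # e.g. device(type="musa", index=0)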
20 changes: 20 additions & 0 deletions src/transformers/trainer.py
@@ -163,6 +163,7 @@
is_torch_compile_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
@@ -2876,6 +2877,17 @@ def _load_rng_state(self, checkpoint):
f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_musa_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
else:
try:
torch.musa.set_rng_state(checkpoint_rng_state["musa"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2964,6 +2976,12 @@ def _save_rng_state(self, output_dir):
else:
rng_states["mlu"] = torch.mlu.random.get_rng_state()

if is_torch_musa_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
rng_states["musa"] = torch.musa.get_rng_state_all()
else:
rng_states["musa"] = torch.musa.get_rng_state()

# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
@@ -3333,6 +3351,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
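The two trainer hunks form a save/restore pair: _save_rng_state stores the generator state under a "musa" key (per-device in distributed mode) and _load_rng_state feeds it back on resume, while training_step additionally clears the MUSA cache alongside the other backends. A standalone sketch of the same round trip, assuming a working torch_musa install:

import torch
import torch_musa  # noqa: F401 (registers the torch.musa namespace)

# Save, advance, then restore the MUSA RNG state, mirroring what the
# Trainer does around checkpoint save/resume (single-device case).
state = torch.musa.get_rng_state()
_ = torch.rand(3, device="musa")   # advances the generator
torch.musa.set_rng_state(state)    # next draws repeat the sequence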
15 changes: 14 additions & 1 deletion src/transformers/trainer_utils.py
@@ -37,6 +37,7 @@
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
@@ -108,6 +109,8 @@ def set_seed(seed: int, deterministic: bool = False):
torch.use_deterministic_algorithms(True)
if is_torch_mlu_available():
torch.mlu.manual_seed_all(seed)
if is_torch_musa_available():
torch.musa.manual_seed_all(seed)
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
@@ -464,7 +467,7 @@ def __init__(self, skip_memory_metrics=False):

import psutil # noqa

if is_torch_cuda_available() or is_torch_mlu_available():
if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
import torch

self.torch = torch
@@ -540,6 +543,9 @@ def start(self):
elif is_torch_mlu_available():
self.torch.mlu.reset_peak_memory_stats()
self.torch.mlu.empty_cache()
elif is_torch_musa_available():
self.torch.musa.reset_peak_memory_stats()
self.torch.musa.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.reset_peak_memory_stats()
self.torch.xpu.empty_cache()
@@ -555,6 +561,8 @@ def start(self):
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
elif is_torch_mlu_available():
self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
elif is_torch_musa_available():
self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
elif is_torch_npu_available():
@@ -588,6 +596,8 @@ def stop(self, stage):
self.torch.cuda.empty_cache()
elif is_torch_mlu_available():
self.torch.mlu.empty_cache()
elif is_torch_musa_available():
self.torch.musa.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.empty_cache()
elif is_torch_npu_available():
@@ -608,6 +618,9 @@ def stop(self, stage):
elif is_torch_mlu_available():
self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
elif is_torch_musa_available():
self.gpu_mem_used_now = self.torch.musa.memory_allocated()
self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
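Beyond seeding, TrainerMemoryTracker now treats MUSA like CUDA and MLU for peak-memory accounting (reset_peak_memory_stats / memory_allocated / max_memory_allocated). A short sketch of the seeding side:

from transformers import set_seed

# One call seeds Python, NumPy, and torch; the new branch also seeds
# all MUSA generators when torch_musa is available.
set_seed(42)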
6 changes: 5 additions & 1 deletion src/transformers/training_args.py
@@ -49,6 +49,7 @@
is_torch_bf16_gpu_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tf32_available,
@@ -1089,7 +1090,7 @@ class TrainingArguments:
default=None,
metadata={
"help": "The backend to be used for distributed training",
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
"choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"],
},
)
tpu_num_cores: Optional[int] = field(
@@ -2200,6 +2201,9 @@ def _setup_devices(self) -> "torch.device":
elif is_torch_mlu_available():
device = torch.device("mlu:0")
torch.mlu.set_device(device)
elif is_torch_musa_available():
device = torch.device("musa:0")
torch.musa.set_device(device)
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
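The added "mccl" choice selects Moore Threads' collective communication library for distributed runs, and _setup_devices pins "musa:0" the same way the MLU and NPU branches do. A minimal sketch (output_dir is illustrative; the backend is only meaningful on multi-GPU MUSA hosts):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", ddp_backend="mccl")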
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -200,6 +200,7 @@
is_torch_fx_proxy,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
23 changes: 23 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -668,6 +668,29 @@ def is_torch_mlu_available(check_device=False):
return hasattr(torch, "mlu") and torch.mlu.is_available()


@lru_cache()
def is_torch_musa_available(check_device=False):
"Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
if not _torch_available or importlib.util.find_spec("torch_musa") is None:
return False

import torch
import torch_musa # noqa: F401

torch_musa_min_version = "0.33.0"
# Despite the variable name, this compares against the installed accelerate
# version: MUSA support is gated on accelerate >= 0.33.0 when accelerate is present.
if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_musa_min_version):
return False

if check_device:
try:
# Will raise a RuntimeError if no MUSA is found
_ = torch.musa.device_count()
return torch.musa.is_available()
except RuntimeError:
return False
return hasattr(torch, "musa") and torch.musa.is_available()


def is_torchdynamo_available():
if not is_torch_available():
return False
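For reference, check_device=True makes the check probe the runtime via torch.musa.device_count() (a RuntimeError is caught and reported as unavailable), while the default path only consults torch.musa.is_available(). A small sketch:

from transformers.utils import is_torch_musa_available

# Probe the device at call time; returns False if the probe raises.
if is_torch_musa_available(check_device=True):
    print("MUSA backend ready")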