diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 863242a695c6..6e70be2798ba 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -3207,6 +3207,8 @@ def get_device_properties() -> DeviceProperties: if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: import torch + if not torch.cuda.is_available(): + return (torch_device, None, None) major, minor = torch.cuda.get_device_capability() if IS_ROCM_SYSTEM: return ("rocm", major, minor)