diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index c3cc4579e5c6..bed115e72d3d 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -299,7 +299,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
 
     if isinstance(tensor, DTensor):
         local_tensor = tensor.to_local()
-        return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
+        return tensor.device, local_tensor.untyped_storage().data_ptr(), tensor.untyped_storage().nbytes()
 
     if tensor.device.type == "xla" and is_torch_xla_available():
         # NOTE: xla tensors dont have storage