From 7396b01f4e2522cfc4e4be26d84242b87b3f4b4f Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 30 Mar 2026 18:39:52 +0800 Subject: [PATCH] Add GDS support for safetensors loading (HF_ENABLE_GDS=1) --- src/transformers/core_model_loading.py | 7 +- src/transformers/modeling_utils.py | 16 ++- src/transformers/utils/gds_io.py | 160 +++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 src/transformers/utils/gds_io.py diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index ce0f2d9cec9b..fa6ecd4c3605 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -781,7 +781,10 @@ def convert( def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor: - # This slicing is what actually loads the tensor from the safetensors slice object + # This slicing is what actually loads the tensor from the safetensors slice object. + # GDS-backed slices use the target device hint to DMA directly to the right GPU. + if hasattr(tensor, "_set_target_device"): + tensor._set_target_device(device) tensor = tensor[...] 
if dtype is not None or device is not None: tensor = tensor.to(device=device, dtype=dtype) @@ -815,6 +818,8 @@ def spawn_tp_materialize( return a Callable that will load the tensor synchronously when called.""" def _job(): + if hasattr(tensor, "_set_target_device"): + tensor._set_target_device(device) return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype) if thread_pool is not None: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f50774ef8065..2e8711cfde50 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -118,6 +118,7 @@ is_torch_xpu_available, logging, ) +from .utils.gds_io import GdsSafetensorsFile, should_use_gds from .utils.generic import GeneralInterface, is_flash_attention_requested from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files from .utils.import_utils import ( @@ -4234,8 +4235,21 @@ def _load_pretrained_model( merged_state_dict = state_dict elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None: merged_state_dict = {} + use_gds = ( + should_use_gds() + and load_config.device_map is not None + and any(is_accelerator_device(d) for d in set(load_config.device_map.values())) + ) for file in checkpoint_files: - file_pointer = safe_open(file, framework="pt", device="cpu") + if use_gds: + try: + file_pointer = GdsSafetensorsFile(file) + except Exception: + logger.info("GDS open failed for %s, falling back to safe_open", file) + use_gds = False + file_pointer = safe_open(file, framework="pt", device="cpu") + else: + file_pointer = safe_open(file, framework="pt", device="cpu") all_pointer.add(file_pointer) for k in file_pointer.keys(): merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet diff --git a/src/transformers/utils/gds_io.py b/src/transformers/utils/gds_io.py new file mode 100644 index 000000000000..223992810e81 --- 
/dev/null +++ b/src/transformers/utils/gds_io.py @@ -0,0 +1,160 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPU Direct Storage (GDS) utilities for safetensors loading. +""" + +import json +import os +import struct +from typing import Any + +import torch +from safetensors import safe_open + +from .import_utils import is_env_variable_true, is_torch_greater_or_equal + + +# Tensors below this use safe_open instead of cuFile (per-call overhead). +_GDS_MIN_BYTES = 1 * 1024 * 1024 + +_gds_available: bool | None = None + + +def is_gds_available() -> bool: + """Check if ``torch.cuda.gds.GdsFile`` is usable. 
Requires PyTorch >= 2.10, CUDA >= 12.6.""" + global _gds_available + if _gds_available is not None: + return _gds_available + _gds_available = False + if not is_torch_greater_or_equal("2.10"): + return False + try: + if not torch.cuda.is_available(): + return False + if hasattr(torch._C, "_gds_is_available"): + _gds_available = torch._C._gds_is_available() + else: + from torch.cuda.gds import GdsFile # noqa: F401 + + _gds_available = True + except (ImportError, AttributeError, RuntimeError): + _gds_available = False + return _gds_available + + +def should_use_gds() -> bool: + """``True`` when GDS is available and opted-in via ``HF_ENABLE_GDS=1``.""" + return is_env_variable_true("HF_ENABLE_GDS") and is_gds_available() + + +# Resolve once at class init, not per-tensor +_str_to_torch_dtype: dict[str, torch.dtype] | None = None + + +def _get_dtype_map() -> dict[str, torch.dtype]: + global _str_to_torch_dtype + if _str_to_torch_dtype is None: + from ..modeling_utils import str_to_torch_dtype + + _str_to_torch_dtype = str_to_torch_dtype + return _str_to_torch_dtype + + +class GdsSafetensorsFile: + """GDS-backed safetensors reader — drop-in replacement for ``safe_open()``.""" + + def __init__(self, filename: str): + self.filename = str(filename) + with open(self.filename, "rb") as f: + header_size = struct.unpack("<Q", f.read(8))[0] + header: dict[str, Any] = json.loads(f.read(header_size)) + # Tensor data starts right after the 8-byte length prefix + JSON header. + data_offset = 8 + header_size + self._tensor_meta = {} + for name, meta in header.items(): + if name == "__metadata__": + continue + begin, end = meta["data_offsets"] + self._tensor_meta[name] = { + "dtype": meta["dtype"], + "shape": meta["shape"], + "file_offset": data_offset + begin, + "nbytes": end - begin, + } + self._dtype_map = _get_dtype_map() + from torch.cuda.gds import GdsFile + + self._gds_file = GdsFile(self.filename, os.O_RDONLY) + # CPU-side fallback reader for small tensors and non-CUDA targets. + self._safe_fp = safe_open(self.filename, framework="pt", device="cpu") + + def keys(self): + return self._safe_fp.keys() + + def get_slice(self, name: str) -> "GdsSlice": + return GdsSlice(self, name) + + def get_tensor(self, name: str, device: torch.device | None = None) -> torch.Tensor: + meta = self._tensor_meta[name] + if device is not None and device.type == "cuda" and meta["nbytes"] >= _GDS_MIN_BYTES: + tensor = torch.empty(meta["shape"], dtype=self._dtype_map[meta["dtype"]], device=device) + self._gds_file.load_storage(tensor.untyped_storage(), meta["file_offset"]) + return tensor + return self._safe_fp.get_tensor(name) + + def __enter__(self): + return self + + def __exit__(self, *_args): + self._close() + + def __del__(self): + self._close() + + def _close(self): + gds = self.__dict__.pop("_gds_file", None)
+ safe = self.__dict__.pop("_safe_fp", None) + del gds + if safe is not None: + safe.__exit__(None, None, None) + + +class GdsSlice: + """Lazy tensor reference compatible with ``PySafeSlice``.""" + + __slots__ = ("_gds_file", "_name", "_dtype", "_shape", "_target_device") + + def __init__(self, gds_file: GdsSafetensorsFile, name: str): + meta = gds_file._tensor_meta[name] + self._gds_file = gds_file + self._name = name + self._dtype = meta["dtype"] + self._shape = meta["shape"] + self._target_device: torch.device | None = None + + def _set_target_device(self, device) -> None: + self._target_device = torch.device(device) if device is not None else None + + def get_dtype(self) -> str: + return self._dtype + + def get_shape(self) -> list[int]: + return self._shape + + def __getitem__(self, slices): + tensor = self._gds_file.get_tensor(self._name, self._target_device) + if slices is Ellipsis or slices == (Ellipsis,): + return tensor + return tensor[slices]