src/transformers/core_model_loading.py (6 additions, 1 deletion)

```diff
@@ -781,7 +781,10 @@ def convert(


 def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
-    # This slicing is what actually loads the tensor from the safetensors slice object
+    # This slicing is what actually loads the tensor from the safetensors slice object.
+    # GDS-backed slices use the target device hint to DMA directly to the right GPU.
+    if hasattr(tensor, "_set_target_device"):
+        tensor._set_target_device(device)
     tensor = tensor[...]
     if dtype is not None or device is not None:
         tensor = tensor.to(device=device, dtype=dtype)
@@ -815,6 +818,8 @@ def spawn_tp_materialize(
     return a Callable that will load the tensor synchronously when called."""
 
     def _job():
+        if hasattr(tensor, "_set_target_device"):
+            tensor._set_target_device(device)
         return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
 
     if thread_pool is not None:
```
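The hook is duck-typed: `_materialize_copy` and `_job` only call `_set_target_device` when the slice object exposes it, so the plain `PySafeSlice` returned by `safe_open` passes through untouched. A minimal sketch of that contract, with an illustrative `RecordingSlice` standing in for the PR's `GdsSlice` (allocation only, no real I/O):

```python
import torch


class RecordingSlice:
    """Illustrative stand-in for GdsSlice: remembers the device hint and
    materializes there when indexed with `...`."""

    def __init__(self, shape):
        self._shape = shape
        self._target_device = None

    def _set_target_device(self, device):
        self._target_device = torch.device(device) if device is not None else None

    def __getitem__(self, slices):
        # A real GdsSlice would DMA bytes from disk here; we just allocate.
        t = torch.empty(self._shape, device=self._target_device or "cpu")
        return t if slices is Ellipsis else t[slices]


def materialize(tensor, device=None, dtype=None):
    # Mirrors the patched _materialize_copy above.
    if hasattr(tensor, "_set_target_device"):
        tensor._set_target_device(device)
    tensor = tensor[...]  # this indexing is what triggers the actual load
    if dtype is not None or device is not None:
        tensor = tensor.to(device=device, dtype=dtype)
    return tensor


print(materialize(RecordingSlice((2, 3)), device="cpu", dtype=torch.float16).dtype)
```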
src/transformers/modeling_utils.py (15 additions, 1 deletion)

```diff
@@ -118,6 +118,7 @@
     is_torch_xpu_available,
     logging,
 )
+from .utils.gds_io import GdsSafetensorsFile, should_use_gds
 from .utils.generic import GeneralInterface, is_flash_attention_requested
 from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
@@ -4234,8 +4235,21 @@ def _load_pretrained_model(
         merged_state_dict = state_dict
     elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
         merged_state_dict = {}
+        use_gds = (
+            should_use_gds()
+            and load_config.device_map is not None
+            and any(is_accelerator_device(d) for d in set(load_config.device_map.values()))
+        )
         for file in checkpoint_files:
-            file_pointer = safe_open(file, framework="pt", device="cpu")
+            if use_gds:
+                try:
+                    file_pointer = GdsSafetensorsFile(file)
+                except Exception:
+                    logger.info("GDS open failed for %s, falling back to safe_open", file)
+                    use_gds = False
+                    file_pointer = safe_open(file, framework="pt", device="cpu")
+            else:
+                file_pointer = safe_open(file, framework="pt", device="cpu")
             all_pointer.add(file_pointer)
             for k in file_pointer.keys():
                 merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
```
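With these pieces in place, GDS loading stays opt-in and everything else rides the usual `from_pretrained` path. A hypothetical end-to-end invocation (the checkpoint name is a placeholder, not from the PR):

```python
import os

# Opt in to GPU Direct Storage; should_use_gds() checks this flag at load time.
os.environ["HF_ENABLE_GDS"] = "1"

from transformers import AutoModelForCausalLM

# The device_map must place weights on an accelerator device, otherwise the
# use_gds gate above falls back to the plain safe_open path.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",  # placeholder checkpoint
    device_map="cuda:0",
)
```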
src/transformers/utils/gds_io.py (new file, 160 additions)

```python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
GPU Direct Storage (GDS) utilities for safetensors loading.
"""

import json
import os
import struct
from typing import Any

import torch
from safetensors import safe_open

from .import_utils import is_env_variable_true, is_torch_greater_or_equal


# Tensors below this use safe_open instead of cuFile (per-call overhead).
_GDS_MIN_BYTES = 1 * 1024 * 1024

_gds_available: bool | None = None


def is_gds_available() -> bool:
    """Check if ``torch.cuda.gds.GdsFile`` is usable. Requires PyTorch >= 2.10, CUDA >= 12.6."""
    global _gds_available
    if _gds_available is not None:
        return _gds_available
    _gds_available = False
    if not is_torch_greater_or_equal("2.10"):
        return False
    try:
        if not torch.cuda.is_available():
            return False
        if hasattr(torch._C, "_gds_is_available"):
            _gds_available = torch._C._gds_is_available()
        else:
            from torch.cuda.gds import GdsFile  # noqa: F401

            _gds_available = True
    except (ImportError, AttributeError, RuntimeError):
        _gds_available = False
    return _gds_available


def should_use_gds() -> bool:
    """``True`` when GDS is available and opted in via ``HF_ENABLE_GDS=1``."""
    return is_env_variable_true("HF_ENABLE_GDS") and is_gds_available()


# Resolve once at class init, not per-tensor
_str_to_torch_dtype: dict[str, torch.dtype] | None = None


def _get_dtype_map() -> dict[str, torch.dtype]:
    global _str_to_torch_dtype
    if _str_to_torch_dtype is None:
        from ..modeling_utils import str_to_torch_dtype

        _str_to_torch_dtype = str_to_torch_dtype
    return _str_to_torch_dtype


class GdsSafetensorsFile:
    """GDS-backed safetensors reader, a drop-in replacement for ``safe_open()``."""

    def __init__(self, filename: str):
        self.filename = str(filename)
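        # safetensors layout: an 8-byte little-endian header length, a JSON
        # header mapping tensor names to {dtype, shape, data_offsets}, then the
        # raw tensor bytes; data_offsets are relative to the end of the header.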
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header = json.loads(f.read(header_size))
            data_offset = 8 + header_size

        self._tensor_meta: dict[str, dict[str, Any]] = {}
        for name, meta in header.items():
            if name == "__metadata__":
                continue
            start, end = meta["data_offsets"]
            self._tensor_meta[name] = {
                "dtype": meta["dtype"],
                "shape": meta["shape"],
                "file_offset": data_offset + start,
                "nbytes": end - start,
            }

        from torch.cuda.gds import GdsFile

        self._gds_file = GdsFile(self.filename, os.O_RDONLY)
        self._safe_fp = safe_open(self.filename, framework="pt", device="cpu")
        self._dtype_map = _get_dtype_map()

    def keys(self):
        return self._tensor_meta.keys()

    def get_slice(self, name: str) -> "GdsSlice":
        return GdsSlice(self, name)

    def get_tensor(self, name: str, device: torch.device | None = None) -> torch.Tensor:
        meta = self._tensor_meta[name]
        if device is not None and device.type == "cuda" and meta["nbytes"] >= _GDS_MIN_BYTES:
            tensor = torch.empty(meta["shape"], dtype=self._dtype_map[meta["dtype"]], device=device)
            self._gds_file.load_storage(tensor.untyped_storage(), meta["file_offset"])
            return tensor
        return self._safe_fp.get_tensor(name)

    def __enter__(self):
        return self

    def __exit__(self, *_args):
        self._close()

    def __del__(self):
        self._close()

    def _close(self):
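        # Pop the handles so a repeated _close() (__exit__ followed by __del__)
        # is a harmless no-op.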
        gds = self.__dict__.pop("_gds_file", None)
        safe = self.__dict__.pop("_safe_fp", None)
        del gds
        if safe is not None:
            safe.__exit__(None, None, None)


class GdsSlice:
    """Lazy tensor reference compatible with ``PySafeSlice``."""

    __slots__ = ("_gds_file", "_name", "_dtype", "_shape", "_target_device")

    def __init__(self, gds_file: GdsSafetensorsFile, name: str):
        meta = gds_file._tensor_meta[name]
        self._gds_file = gds_file
        self._name = name
        self._dtype = meta["dtype"]
        self._shape = meta["shape"]
        self._target_device: torch.device | None = None

    def _set_target_device(self, device) -> None:
        self._target_device = torch.device(device) if device is not None else None

    def get_dtype(self) -> str:
        return self._dtype

    def get_shape(self) -> list[int]:
        return self._shape

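    # Each __getitem__ call materializes the full tensor before applying the
    # slice; see the review thread below on the resulting memory footprint.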
    def __getitem__(self, slices):
        tensor = self._gds_file.get_tensor(self._name, self._target_device)
        if slices is Ellipsis or slices == (Ellipsis,):
            return tensor
        return tensor[slices]
```

Review thread on `GdsSlice.__getitem__`:

Member: Does this mean you pull the full tensor for each slice that is requested? What does your memory footprint on device look like once the model is loaded?

Contributor (author): Not sure I follow the question. The purpose is to load tensors via the GDS API, which works best with aligned file offsets.

Member: How do you guarantee file offsets are aligned here? safetensors files aren't written with that constraint in mind; you need to do some extra processing (we're thinking of supporting writing aligned offsets, but it's tricky wrt backwards compatibility).

What I'm asking is that, from your implementation, it seems you're loading the full tensor `self._name` on each call to `GdsSlice.__getitem__`. That is why I asked what the memory footprint (total used memory on device) looks like. If you can run `nvidia-smi` after loading the model, that'd be a good test to see if that happens.
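The alignment question from the thread can be checked offline. A minimal sketch, not part of the PR, that parses a safetensors header and reports how many tensor payloads start on a 4 KiB boundary (a commonly used DMA-friendly alignment; the `alignment_report` helper is hypothetical):

```python
import json
import struct
import sys

ALIGN = 4096  # assumed DMA-friendly boundary; unaligned cuFile reads still work, at a cost


def alignment_report(path: str) -> None:
    # Same header parsing as GdsSafetensorsFile.__init__ above.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    data_offset = 8 + header_size
    offsets = [
        data_offset + meta["data_offsets"][0]
        for name, meta in header.items()
        if name != "__metadata__"
    ]
    aligned = sum(off % ALIGN == 0 for off in offsets)
    print(f"{aligned}/{len(offsets)} tensors start on a {ALIGN}-byte boundary")


if __name__ == "__main__":
    alignment_report(sys.argv[1])
```

For the memory-footprint half of the question, comparing `torch.cuda.memory_allocated()` (or `nvidia-smi`) after loading with and without `HF_ENABLE_GDS=1` would show whether full-tensor materialization per slice leaves extra device memory resident.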