Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
from packaging import version
from safetensors import safe_open
from safetensors.torch import load as _safe_load_bytes
from safetensors.torch import save_file as safe_save_file
from torch import Tensor, nn
from torch.distributions import constraints
Expand Down Expand Up @@ -178,6 +179,7 @@ class LoadStateDictConfig:
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
weights_only: bool = True
weight_mapping: list[WeightConverter | WeightRenaming] | None = None
disable_mmap: bool | None = None

@property
def is_quantized(self) -> bool:
Expand Down Expand Up @@ -288,14 +290,54 @@ def get_state_dict_dtype(state_dict):
}


def _is_on_hf_mount(path: "str | os.PathLike") -> bool:
"""True if `path` lives on an hf-mount FUSE filesystem (device string 'hf-mount').

hf-mount's mmap + readahead interaction deadlocks under parallel page-faults,
so callers should load the file into memory instead. Linux-only; returns False
on other platforms.
"""
if not sys.platform.startswith("linux"):
return False
try:
real = os.path.realpath(os.fspath(path))
with open("/proc/mounts", encoding="utf-8") as fh:
entries = sorted(
((p[0], p[1]) for p in (l.split() for l in fh) if len(p) >= 2),
key=lambda e: len(e[1]),
reverse=True,
)
for dev, mp in entries:
if real == mp or real.startswith(mp.rstrip("/") + "/"):
return dev == "hf-mount"
except (OSError, ValueError):
pass
return False


def load_state_dict(
checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
checkpoint_file: str | os.PathLike,
map_location: str | torch.device = "cpu",
weights_only: bool = True,
disable_mmap: bool | None = None,
) -> dict[str, torch.Tensor]:
"""
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.

When `disable_mmap` is True, safetensors files are read fully into memory instead of
being memory-mapped. When `disable_mmap` is None (default), it is auto-detected to True
on hf-mount FUSE filesystems (see `_is_on_hf_mount`).
"""
if disable_mmap is None:
disable_mmap = _is_on_hf_mount(checkpoint_file)
# Use safetensors if possible
if checkpoint_file.endswith(".safetensors"):
if disable_mmap and map_location != "meta":
with open(checkpoint_file, "rb") as _fh:
state_dict = _safe_load_bytes(_fh.read())
if map_location != "cpu":
state_dict = {k: v.to(map_location) for k, v in state_dict.items()}
return state_dict
with safe_open(checkpoint_file, framework="pt") as f:
state_dict = {}
for k in f.keys():
Expand Down Expand Up @@ -3699,6 +3741,7 @@ def from_pretrained(
use_safetensors: bool | None = None,
weights_only: bool = True,
fusion_config: dict[str, bool | dict[str, Any]] | None = None,
disable_mmap: bool | None = None,
**kwargs,
) -> SpecificPreTrainedModelType:
r"""
Expand Down Expand Up @@ -3877,6 +3920,12 @@ def from_pretrained(
Indicates whether unpickler should be restricted to loading only tensors, primitive types,
dictionaries and any types added via torch.serialization.add_safe_globals().
When set to False, we can load wrapper tensor subclass weights.
disable_mmap (`bool`, *optional*):
Whether to disable memory mapping when loading safetensors checkpoints. When `None` (default),
it is auto-detected to `True` when the checkpoint lives on an `hf-mount` FUSE filesystem
(used by HF Spaces/Endpoints), where mmap + parallel page-faults can deadlock. When `True`,
files are read fully into memory and parsed with `safetensors.torch.load`. When `False`, the
default memory-mapped loader is always used.
fusion_config (`dict[str, bool | dict[str, Any]]`, *optional*):
Optional fusion configuration applied before model instantiation. Each key enables a fusion family and
its value can either be `True` to enable that fusion with default options or a dictionary of
Expand Down Expand Up @@ -4156,6 +4205,7 @@ def from_pretrained(
weight_mapping=weight_conversions,
use_safetensors=use_safetensors,
download_kwargs=download_kwargs,
disable_mmap=disable_mmap,
)
loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
loading_info = cls._finalize_model_loading(model, load_config, loading_info)
Expand Down Expand Up @@ -4246,7 +4296,12 @@ def _load_pretrained_model(
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(
load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
load_state_dict(
ckpt_file,
map_location="cpu",
weights_only=load_config.weights_only,
disable_mmap=load_config.disable_mmap,
)
)
state_dict = merged_state_dict
error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
Expand All @@ -4265,6 +4320,10 @@ def _load_pretrained_model(
elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
merged_state_dict = {}
for file in checkpoint_files:
if load_config.disable_mmap or _is_on_hf_mount(file):
with open(file, "rb") as _fh:
merged_state_dict.update(_safe_load_bytes(_fh.read()))
continue
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
Expand All @@ -4273,7 +4332,7 @@ def _load_pretrained_model(
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
merged_state_dict.update(load_state_dict(ckpt_file, disable_mmap=load_config.disable_mmap))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

Expand Down
64 changes: 64 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3451,3 +3451,67 @@ def test_vision_language_model(self):
assert image_encoder is model.model.vision_tower, (
f"LLaVA get_encoder(modality='image') should return vision_tower, got {type(image_encoder)}"
)


@require_torch
class DisableMmapLoadingTest(unittest.TestCase):
    """Tests for the `disable_mmap` kwarg in `load_state_dict` and the `_is_on_hf_mount` helper."""

    def _fake_open_factory(self, proc_mounts_contents):
        """Return a patched `open` that serves `proc_mounts_contents` for `/proc/mounts` and defers otherwise."""
        import builtins

        original_open = builtins.open

        def patched_open(path, *args, **kwargs):
            # Only intercept the mounts table; every other open stays real.
            if path != "/proc/mounts":
                return original_open(path, *args, **kwargs)
            import io

            return io.StringIO(proc_mounts_contents)

        return patched_open

    def test_is_on_hf_mount_linux_match(self):
        from transformers.modeling_utils import _is_on_hf_mount

        mounts = (
            "proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0\n"
            "hf-mount /data fuse.hf-mount rw,nosuid,nodev,relatime,user_id=0 0 0\n"
        )
        fake = self._fake_open_factory(mounts)
        with patch("sys.platform", "linux"), patch("builtins.open", fake):
            self.assertTrue(_is_on_hf_mount("/data/model.safetensors"))

    def test_is_on_hf_mount_no_match(self):
        from transformers.modeling_utils import _is_on_hf_mount

        mounts = "proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0\n/dev/nvme0n1p1 /data ext4 rw,relatime 0 0\n"
        fake = self._fake_open_factory(mounts)
        with patch("sys.platform", "linux"), patch("builtins.open", fake):
            self.assertFalse(_is_on_hf_mount("/data/model.safetensors"))

    def test_is_on_hf_mount_non_linux(self):
        from transformers.modeling_utils import _is_on_hf_mount

        # Detection is Linux-only; on any other platform it must report False.
        with patch("sys.platform", "darwin"):
            self.assertFalse(_is_on_hf_mount("/data/model.safetensors"))

    def test_load_state_dict_disable_mmap_explicit(self):
        import torch
        from safetensors.torch import save_file as safe_save_file

        from transformers.modeling_utils import load_state_dict

        tensors = {
            "weight": torch.arange(12, dtype=torch.float32).reshape(3, 4),
            "bias": torch.tensor([1.0, 2.0, 3.0]),
        }
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = os.path.join(tmpdir, "model.safetensors")
            safe_save_file(tensors, checkpoint)

            # Both code paths (mmap and fully-in-memory) must yield identical tensors.
            via_mmap = load_state_dict(checkpoint, disable_mmap=False)
            in_memory = load_state_dict(checkpoint, disable_mmap=True)

            self.assertEqual(set(via_mmap.keys()), set(in_memory.keys()))
            for name in via_mmap:
                torch.testing.assert_close(via_mmap[name], in_memory[name])
Loading