diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index db2ef1b3323a..eb092019b678 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -36,6 +36,7 @@
 from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
 from packaging import version
 from safetensors import safe_open
+from safetensors.torch import load as _safe_load_bytes
 from safetensors.torch import save_file as safe_save_file
 from torch import Tensor, nn
 from torch.distributions import constraints
@@ -178,6 +179,7 @@ class LoadStateDictConfig:
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
     weights_only: bool = True
     weight_mapping: list[WeightConverter | WeightRenaming] | None = None
+    disable_mmap: bool | None = None
 
     @property
     def is_quantized(self) -> bool:
@@ -288,14 +290,54 @@ def get_state_dict_dtype(state_dict):
 }
 
 
+def _is_on_hf_mount(path: "str | os.PathLike") -> bool:
+    """True if `path` lives on an hf-mount FUSE filesystem (device string 'hf-mount').
+
+    hf-mount's mmap + readahead interaction deadlocks under parallel page-faults,
+    so callers should load the file into memory instead. Linux-only; returns False
+    on other platforms.
+    """
+    if not sys.platform.startswith("linux"):
+        return False
+    try:
+        real = os.path.realpath(os.fspath(path))
+        with open("/proc/mounts", encoding="utf-8") as fh:
+            entries = sorted(  # longest mount point first, so nested mounts resolve to the innermost match
+                ((p[0], p[1]) for p in (ln.split() for ln in fh) if len(p) >= 2),
+                key=lambda e: len(e[1]),
+                reverse=True,
+            )
+            for dev, mp in entries:
+                if real == mp or real.startswith(mp.rstrip("/") + "/"):
+                    return dev == "hf-mount"
+    except (OSError, ValueError):
+        pass
+    return False
+
+
 def load_state_dict(
-    checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
+    checkpoint_file: str | os.PathLike,
+    map_location: str | torch.device = "cpu",
+    weights_only: bool = True,
+    disable_mmap: bool | None = None,
 ) -> dict[str, torch.Tensor]:
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
+
+    When `disable_mmap` is `True`, safetensors files are read fully into memory instead of
+    being memory-mapped. When `disable_mmap` is `None` (the default), it resolves to `True`
+    on hf-mount FUSE filesystems (see `_is_on_hf_mount`).
     """
+    if disable_mmap is None:
+        disable_mmap = _is_on_hf_mount(checkpoint_file)
     # Use safetensors if possible
     if checkpoint_file.endswith(".safetensors"):
+        if disable_mmap and map_location != "meta":  # "meta" never materializes tensor data, so mmap is safe there
+            with open(checkpoint_file, "rb") as _fh:
+                state_dict = _safe_load_bytes(_fh.read())
+            if map_location != "cpu":
+                state_dict = {k: v.to(map_location) for k, v in state_dict.items()}
+            return state_dict
         with safe_open(checkpoint_file, framework="pt") as f:
             state_dict = {}
             for k in f.keys():
@@ -3699,6 +3741,7 @@ def from_pretrained(
         use_safetensors: bool | None = None,
         weights_only: bool = True,
         fusion_config: dict[str, bool | dict[str, Any]] | None = None,
+        disable_mmap: bool | None = None,
         **kwargs,
     ) -> SpecificPreTrainedModelType:
         r"""
@@ -3877,6 +3920,12 @@
                 Indicates whether unpickler should be restricted to loading only tensors, primitive types,
                 dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can
                 load wrapper tensor subclass weights.
+            disable_mmap (`bool`, *optional*):
+                Whether to disable memory mapping when loading safetensors checkpoints. When `None` (the
+                default), it resolves to `True` when the checkpoint lives on an `hf-mount` FUSE filesystem
+                (used by HF Spaces/Endpoints), where mmap plus parallel page-faults can deadlock. When `True`,
+                files are read fully into memory and parsed with `safetensors.torch.load`. When `False`, the
+                default memory-mapped loader is always used.
             fusion_config (`dict[str, bool | dict[str, Any]]`, *optional*):
                 Optional fusion configuration applied before model instantiation. Each key enables a fusion family
                 and its value can either be `True` to enable that fusion with default options or a dictionary of
@@ -4156,6 +4205,7 @@
             weight_mapping=weight_conversions,
             use_safetensors=use_safetensors,
             download_kwargs=download_kwargs,
+            disable_mmap=disable_mmap,
         )
         loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
         loading_info = cls._finalize_model_loading(model, load_config, loading_info)
@@ -4246,7 +4296,12 @@ def _load_pretrained_model(
             merged_state_dict = {}
             for ckpt_file in checkpoint_files:
                 merged_state_dict.update(
-                    load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
+                    load_state_dict(
+                        ckpt_file,
+                        map_location="cpu",
+                        weights_only=load_config.weights_only,
+                        disable_mmap=load_config.disable_mmap,
+                    )
                 )
             state_dict = merged_state_dict
             error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
@@ -4265,6 +4320,10 @@
         elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
             merged_state_dict = {}
             for file in checkpoint_files:
+                if load_config.disable_mmap or (load_config.disable_mmap is None and _is_on_hf_mount(file)):
+                    with open(file, "rb") as _fh:
+                        merged_state_dict.update(_safe_load_bytes(_fh.read()))
+                    continue
                 file_pointer = safe_open(file, framework="pt", device="cpu")
                 all_pointer.add(file_pointer)
                 for k in file_pointer.keys():
@@ -4273,7 +4332,7 @@
         elif checkpoint_files is not None:
             merged_state_dict = {}
             for ckpt_file in checkpoint_files:
-                merged_state_dict.update(load_state_dict(ckpt_file))
+                merged_state_dict.update(load_state_dict(ckpt_file, disable_mmap=load_config.disable_mmap))
         else:
             raise ValueError("Neither a state dict nor checkpoint files were found.")
 
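
Reviewer aid, not part of the patch: a minimal sketch of the three `disable_mmap` modes added above. The checkpoint path is hypothetical; the functions are the ones introduced in this diff.

    from transformers.modeling_utils import load_state_dict

    ckpt = "/data/model.safetensors"  # hypothetical path on an hf-mount volume

    load_state_dict(ckpt)                      # None (default): auto-detects via _is_on_hf_mount
    load_state_dict(ckpt, disable_mmap=True)   # forces a full in-memory read
    load_state_dict(ckpt, disable_mmap=False)  # forces the default memory-mapped loader
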
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 7366845c4d78..6a27b6b5e0fb 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -3451,3 +3451,67 @@ def test_vision_language_model(self):
         assert image_encoder is model.model.vision_tower, (
             f"LLaVA get_encoder(modality='image') should return vision_tower, got {type(image_encoder)}"
         )
+
+
+@require_torch
+class DisableMmapLoadingTest(unittest.TestCase):
+    """Tests for the `disable_mmap` kwarg in `load_state_dict` and the `_is_on_hf_mount` helper."""
+
+    def _fake_open_factory(self, proc_mounts_contents):
+        """Return a patched `open` that serves `proc_mounts_contents` for `/proc/mounts` and defers otherwise."""
+        import builtins
+
+        real_open = builtins.open
+
+        def fake_open(path, *args, **kwargs):
+            if path == "/proc/mounts":
+                import io
+
+                return io.StringIO(proc_mounts_contents)
+            return real_open(path, *args, **kwargs)
+
+        return fake_open
+
+    def test_is_on_hf_mount_linux_match(self):
+        from transformers.modeling_utils import _is_on_hf_mount
+
+        mounts = (
+            "proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0\n"
+            "hf-mount /data fuse.hf-mount rw,nosuid,nodev,relatime,user_id=0 0 0\n"
+        )
+        with patch("sys.platform", "linux"), patch("builtins.open", self._fake_open_factory(mounts)):
+            self.assertTrue(_is_on_hf_mount("/data/model.safetensors"))
+
+    def test_is_on_hf_mount_no_match(self):
+        from transformers.modeling_utils import _is_on_hf_mount
+
+        mounts = "proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0\n/dev/nvme0n1p1 /data ext4 rw,relatime 0 0\n"
+        with patch("sys.platform", "linux"), patch("builtins.open", self._fake_open_factory(mounts)):
+            self.assertFalse(_is_on_hf_mount("/data/model.safetensors"))
+
+    def test_is_on_hf_mount_non_linux(self):
+        from transformers.modeling_utils import _is_on_hf_mount
+
+        with patch("sys.platform", "darwin"):
+            self.assertFalse(_is_on_hf_mount("/data/model.safetensors"))
+
+    def test_load_state_dict_disable_mmap_explicit(self):
+        import torch
+        from safetensors.torch import save_file as safe_save_file
+
+        from transformers.modeling_utils import load_state_dict
+
+        state_dict = {
+            "weight": torch.arange(12, dtype=torch.float32).reshape(3, 4),
+            "bias": torch.tensor([1.0, 2.0, 3.0]),
+        }
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ckpt_path = os.path.join(tmpdir, "model.safetensors")
+            safe_save_file(state_dict, ckpt_path)
+
+            loaded_mmap = load_state_dict(ckpt_path, disable_mmap=False)
+            loaded_no_mmap = load_state_dict(ckpt_path, disable_mmap=True)
+
+            self.assertEqual(set(loaded_mmap.keys()), set(loaded_no_mmap.keys()))
+            for k in loaded_mmap:
+                torch.testing.assert_close(loaded_mmap[k], loaded_no_mmap[k])
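
End-to-end, assuming this patch is applied: the new kwarg flows from `from_pretrained` into `LoadStateDictConfig` and down to every `load_state_dict` call. The local path below is hypothetical, and `AutoModelForCausalLM` is just one possible entry point; any `PreTrainedModel.from_pretrained` call accepts the kwarg.

    from transformers import AutoModelForCausalLM

    # On an hf-mount FUSE volume the in-memory path is now picked automatically;
    # disable_mmap=True forces it on any filesystem.
    model = AutoModelForCausalLM.from_pretrained("/data/my-model", disable_mmap=True)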