
Add GDS support for safetensors loading #45113

Open
cyyever wants to merge 1 commit into huggingface:main from cyyever:gds-support

Conversation

Contributor

@cyyever cyyever commented Mar 30, 2026

What does this PR do?

This PR adds GPU Direct Storage (GDS) support for safetensors model loading via torch.cuda.gds.GdsFile. GDS is disabled by default; set the HF_ENABLE_GDS=1 environment variable to enable it.
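
For reference, a minimal sketch of the PyTorch GDS primitive this PR builds on. The file name, tensor shape, and offset below are illustrative only; in the actual implementation they would come from the safetensors header.

import os

import torch
from torch.cuda import gds

# Illustrative values: the real shape/dtype/offset come from the checkpoint header.
dst = torch.empty(1024, dtype=torch.float16, device="cuda:0")  # destination tensor on GPU
f = gds.GdsFile("model.safetensors", os.O_RDONLY)              # cuFile-backed handle
# DMA the bytes starting at `offset` straight from disk into the tensor's GPU storage,
# bypassing a host-memory staging copy.
f.load_storage(dst.untyped_storage(), offset=4096)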

Benchmark

A100 PCIe 40GB, Samsung NVMe 3.5TB, GDS compatibility mode (no nvidia-fs), measured in from_pretrained:

| Model | safe_open | GDS | Speedup |
|-------|-----------|-----|---------|
| Qwen/Qwen2.5-0.5B | 1.011s | 1.027s | 0.98x |
| Qwen/Qwen2.5-7B | 3.133s | 1.725s | 1.82x |

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@cyyever cyyever force-pushed the gds-support branch 7 times, most recently from 430ea29 to e38eb51 Compare March 30, 2026 11:17
@cyyever cyyever marked this pull request as draft March 30, 2026 12:53
@cyyever cyyever marked this pull request as ready for review March 30, 2026 13:24
Member

@McPatate McPatate left a comment


Hi, thanks for the contribution! We're working on these kinds of topics over in the https://github.com/huggingface/safetensors repo directly; it would probably be best to have GDS support in the library itself rather than in transformers, cc @ArthurZucker.

Would be curious to see the following:

  • larger model load
  • distributed loading
  • running iostat during load to measure throughput
  • fio theoretical max throughput on your machine for reference
  • warm vs cold cache test (sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches)
  • OS, I assume Linux here?
  • I assume each issued read is not sequential, since there might be overlap between tensor ranges as you're pulling full tensor data for each slice, IIUC. I wonder how much that impacts reads in the context of GDS, though. Would be cool if you could test that, but it'll require a more involved setup, I think.

Not sure this will scale well with larger models, lmk if you want to test some more after my feedback. IMO not worth it as we're going to tackle this in the coming weeks in a "specialised manner" over in https://github.com/mfuntowicz/hmll, cc @mfuntowicz

        return self._shape

    def __getitem__(self, slices):
        tensor = self._gds_file.get_tensor(self._name, self._target_device)
Member

Does this mean you pull the full tensor for each slice that is requested? What does your memory footprint on device look like once the model is loaded?

Contributor Author

I'm not sure I understand your question. The purpose is to load tensors via the GDS API, which works best with aligned file offsets.

Member

How do you guarantee file offsets are aligned here? safetensors files aren't written with that constraint in mind; you need to do some extra processing (we're thinking of supporting writing aligned offsets, but it's tricky wrt backwards compatibility).
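
For context, here is a small sketch for checking how (mis)aligned a checkpoint's tensor offsets are, assuming the standard safetensors layout (an 8-byte little-endian header length, a JSON header, then the data buffer); the file path is a placeholder.

import json
import struct

path = "model.safetensors"  # placeholder path
with open(path, "rb") as f:
    header_len = struct.unpack("<Q", f.read(8))[0]  # u64 size of the JSON header
    header = json.loads(f.read(header_len))         # per-tensor metadata
data_start = 8 + header_len                         # data buffer begins after the header
for name, meta in header.items():
    if name == "__metadata__":
        continue
    begin, _end = meta["data_offsets"]              # offsets relative to the data buffer
    abs_off = data_start + begin
    # cuFile/GDS prefers 4 KiB-aligned offsets for direct reads.
    print(f"{name}: offset={abs_off} aligned_4k={abs_off % 4096 == 0}")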

What I'm asking is: from your implementation, it seems you're loading the full tensor self._name on each call to GdsSlice.__getitem__. That is why I asked what the memory footprint (total used memory on device) looks like. If you can run nvidia-smi after loading the model, that'd be a good test to see if that happens.
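
A simple way to run that footprint check from Python, assuming the HF_ENABLE_GDS switch described in this PR (the model id and device are arbitrary):

import os

import torch

os.environ["HF_ENABLE_GDS"] = "1"  # opt into the GDS path added by this PR

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B", device_map="cuda:0", dtype=torch.float16
)
# If each __getitem__ materialises the full tensor, the peak allocation will be
# noticeably larger than the checkpoint size on disk.
print(f"allocated now : {torch.cuda.memory_allocated(0) / 2**30:.2f} GiB")
print(f"peak allocated: {torch.cuda.max_memory_allocated(0) / 2**30:.2f} GiB")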

@cyyever
Contributor Author

cyyever commented Mar 31, 2026

@McPatate I saw a similar PR in the safetensors repo, but unfortunately it was rejected. I prefer to add GDS in transformers because it gives a global view of IO bottlenecks in large-scale LLM training/inference scenarios.
This PR is also easier to review because PyTorch already provides GDS primitives, whereas in safetensors we would have to write Rust code to wrap the cuFile API.

For your concerns:

  1. larger model load => provided.
  2. distributed loading => we can load to different GPU cards on the same host via device mapping
  3. running iostat during load to measure throughput => I ran the benchmarks on a busy server where other processes were training models, so it was hard to isolate the working set.
  4. fio theoretical max throughput on your machine for reference => it depends on the specification of the tested SSDs, the filesystem settings (we use EXT4 with noatime), and other kernel configs.
  5. The reported figures are from a warm-cache test. Cold-cache tests make less sense on real servers where multiple users share a common mount point to store downloaded LLMs.
  6. GDS only supports Linux.
  7. Sequential IO access dominates the benchmark due to the tensor disk layout, especially when loading LLM checkpoints.
    Could you help verify my findings using the benchmark code below?
#!/usr/bin/env python3
"""Benchmark GDS vs safe_open for from_pretrained loading.

Usage:
    HF_ENABLE_GDS=1 python benchmarks/benchmark_gds.py
    HF_ENABLE_GDS=1 python benchmarks/benchmark_gds.py --models Qwen/Qwen2.5-7B
    HF_ENABLE_GDS=1 python benchmarks/benchmark_gds.py --models Qwen/Qwen2.5-7B --device 2
"""

import argparse
import gc
import os
import pathlib
import statistics
import time

import torch


def _get_storage_info() -> str:
    """Return a one-line description of the storage device under the HF cache."""
    try:
        cache_dir = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
        # Resolve to real mount point
        st = os.stat(cache_dir)
        dev_major, dev_minor = os.major(st.st_dev), os.minor(st.st_dev)
        # Find the block device name from /sys/dev/block
        sys_path = pathlib.Path(f"/sys/dev/block/{dev_major}:{dev_minor}")
        if sys_path.is_symlink():
            block_name = sys_path.resolve().name
        else:
            block_name = None
        # Walk up to the whole-disk device
        disk_name = block_name
        if not pathlib.Path(f"/sys/block/{disk_name}").exists():
            # e.g. nvme0n1p1 -> nvme0n1, sda2 -> sda
            parent = pathlib.Path(f"/sys/class/block/{block_name}").resolve().parent.name
            if pathlib.Path(f"/sys/block/{parent}").exists():
                disk_name = parent
        model = "unknown"
        for p in [f"/sys/block/{disk_name}/device/model", f"/sys/class/block/{disk_name}/device/model"]:
            if os.path.isfile(p):
                model = open(p).read().strip()
                break
        rotational = "HDD" if open(f"/sys/block/{disk_name}/queue/rotational").read().strip() == "1" else "SSD/NVMe"
        return f"{model} ({rotational}, /dev/{block_name})"
    except Exception:
        return "unknown"


def bench_from_pretrained(model_id: str, device: str, use_gds: bool, warmup: int = 1, repeats: int = 3):
    from transformers import AutoModelForCausalLM

    os.environ["HF_ENABLE_GDS"] = "1" if use_gds else "0"
    times = []
    for i in range(warmup + repeats):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        start = time.perf_counter()
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, dtype=torch.float16)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        if i >= warmup:
            times.append(elapsed)
        del model
        gc.collect()
        torch.cuda.empty_cache()
    os.environ.pop("HF_ENABLE_GDS", None)
    return times


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", nargs="+", default=["Qwen/Qwen2.5-0.5B", "meta-llama/Llama-3.2-1B"])
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--warmup", type=int, default=1)
    parser.add_argument("--repeats", type=int, default=3)
    args = parser.parse_args()

    device = f"cuda:{args.device}"
    print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}, {torch.cuda.get_device_name(args.device)}")
    print(f"Storage: {_get_storage_info()}")

    from transformers.utils.gds_io import is_gds_available

    print(f"GDS available: {is_gds_available()}")
    print()
    print("| Model | safe_open | GDS | Speedup |")
    print("|-------|-----------|-----|---------|")

    for model_id in args.models:
        t_b = bench_from_pretrained(model_id, device, False, args.warmup, args.repeats)
        t_g = bench_from_pretrained(model_id, device, True, args.warmup, args.repeats)
        m_b, m_g = statistics.mean(t_b), statistics.mean(t_g)
        print(f"| {model_id} | {m_b:.3f}s | {m_g:.3f}s | **{m_b / m_g:.2f}x** |")


if __name__ == "__main__":
    main()

@cyyever cyyever requested a review from McPatate March 31, 2026 14:45
@McPatate
Member

larger model load => provided.
distributed loading => we can load to different GPU cards on the same host via device mapping

By larger I meant larger than Qwen 7B! But I assume we're entering distributed territory after that size, so consider these two points as the same. I would appreciate a benchmark to see how your impl performs in that scenario.

I ran the benchmarks on a busy server where other processes were training models, so it was hard to isolate the working set.

No way to run your code on an isolated machine?

fio theoretical max throughput on your machine for reference => it depends on the specification of the tested SSDs, the filesystem settings (we use EXT4 with noatime), and other kernel configs.

I was asking if you could run fio to have a reference point of "theoretical disk throughput" vs "actual throughput of GDS loading".

The reported figures are from a warm-cache test. Cold-cache tests make less sense on real servers where multiple users share a common mount point to store downloaded LLMs.

There are scenarios where cold cache runs make sense (e.g. loading a model after a restart with files already present on disk).
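
For what it's worth, a cold-cache variant of the benchmark above could wrap each measured load with a page-cache drop (requires root); this is a sketch mirroring the sync && drop_caches command quoted earlier, not part of the PR.

import subprocess

def drop_page_cache() -> None:
    """Flush dirty pages and drop the Linux page cache (must run as root)."""
    subprocess.run(["sync"], check=True)          # flush dirty pages to disk
    with open("/proc/sys/vm/drop_caches", "w") as f:
        f.write("3\n")                            # 3 = drop pagecache + dentries/inodes

Calling drop_page_cache() right before each timed from_pretrained in bench_from_pretrained would turn the warm-cache measurement into a cold-cache one.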

Sequential IO access dominates the benchmark due to the tensor disk layout, especially when loading LLM checkpoints.

I'm not convinced you are issuing reads sequentially, as for each slice you read the full tensor, again IIUC.
