From d588be5c4e4e25dfa4c83539b5341f5d7054960b Mon Sep 17 00:00:00 2001 From: Gideon Giffard <118290024+fardeon@users.noreply.github.com> Date: Fri, 21 Apr 2023 10:28:47 +0800 Subject: [PATCH] build(utils): allow to download safetensors format --- utils/download.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/utils/download.py b/utils/download.py index 7dbb39f7..6168a637 100644 --- a/utils/download.py +++ b/utils/download.py @@ -3,6 +3,7 @@ This allows memory-constrained CI/CD runners to build container images with large bundled models. See ../deployments/bundle/ for examples. """ +import os import sys import tempfile @@ -11,6 +12,11 @@ if len(sys.argv) < 3: sys.exit("usage: python download.py REPO_ID LOCAL_DIR [REVISION]") +if os.getenv("TENSOR_FORMAT") == "safetensors": + tensor_format = "*.safetensors" +else: + tensor_format = "*.bin" + with tempfile.TemporaryDirectory() as cache_dir: huggingface_hub.snapshot_download( repo_id=sys.argv[1], @@ -19,5 +25,5 @@ cache_dir=cache_dir, local_dir_use_symlinks=False, resume_download=True, - allow_patterns=["*.bin", "*.json", "*.model", "*.py"], + allow_patterns=[tensor_format, "*.json", "*.model", "*.py"], )