Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions src/transformers/dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,42 @@ def get_class_in_module(
return getattr(module, class_name)


def _compute_local_source_files_hash(
pretrained_model_name_or_path: str | os.PathLike,
module_file: str | os.PathLike,
resolved_module_file: str | os.PathLike,
modules_needed: list[str],
) -> str:
"""
Computes a stable hash from the bytes of the local source file and its direct relative-import source files.
"""
model_path = Path(pretrained_model_name_or_path).resolve()
module_parent = Path(module_file).parent

resolved_module_file = Path(resolved_module_file).resolve()

def _resolve_relative_source_path(source_file_path: Path) -> str:
try:
return source_file_path.relative_to(model_path).as_posix()
except ValueError:
# Fallback for edge cases where the source file is not under the local model directory.
return source_file_path.as_posix()

files_to_hash = [
(_resolve_relative_source_path(resolved_module_file), resolved_module_file),
]
for module_needed in modules_needed:
module_needed_path = (model_path / module_parent / f"{module_needed}.py").resolve()
files_to_hash.append((_resolve_relative_source_path(module_needed_path), module_needed_path))

source_files_hash = hashlib.sha256()
for relative_path, file_path in sorted(files_to_hash, key=lambda entry: entry[0]):
source_files_hash.update(relative_path.encode("utf-8"))
source_files_hash.update(file_path.read_bytes())

return source_files_hash.hexdigest()[:16]


def get_cached_module_file(
pretrained_model_name_or_path: str | os.PathLike,
module_file: str,
Expand Down Expand Up @@ -376,9 +412,8 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))
else:
cached_module = None
if not is_local:
submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/")))
cached_module = try_to_load_from_cache(
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
Expand Down Expand Up @@ -408,12 +443,21 @@ def get_cached_module_file(

# Check we have all the requirements in our environment
modules_needed = check_imports(resolved_module_file)
if is_local:
local_model_name = _sanitize_module_name(os.path.basename(os.path.normpath(pretrained_model_name_or_path)))
local_source_files_hash = _compute_local_source_files_hash(
pretrained_model_name_or_path, module_file, resolved_module_file, modules_needed
)
if local_model_name:
submodule = os.path.sep.join([local_model_name, local_source_files_hash])
else:
submodule = local_source_files_hash

# Now we move the module inside our cached dynamic modules.
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)):
if is_local:
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
# has changed since last copy.
if not (submodule_path / module_file).exists() or not filecmp.cmp(
Expand Down
74 changes: 73 additions & 1 deletion tests/utils/test_dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

import os
from pathlib import Path

import pytest

from transformers.dynamic_module_utils import get_imports
from transformers import dynamic_module_utils
from transformers.dynamic_module_utils import get_cached_module_file, get_imports


TOP_LEVEL_IMPORT = """
Expand Down Expand Up @@ -127,3 +129,73 @@ def test_import_parsing(tmp_path, case):

parsed_imports = get_imports(tmp_file_path)
assert parsed_imports == ["os"]


def _create_local_module(module_dir: Path, module_code: str, helper_code: str | None = None):
module_dir.mkdir(parents=True, exist_ok=True)
(module_dir / "custom_model.py").write_text(module_code, encoding="utf-8")
if helper_code is not None:
(module_dir / "helper.py").write_text(helper_code, encoding="utf-8")


def test_get_cached_module_file_local_cache_key_uses_basename_and_content_hash(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))

model_dir_a = tmp_path / "pretrained_a" / "subdir"
model_dir_b = tmp_path / "pretrained_b" / "subdir"
model_dir_c = tmp_path / "pretrained_c" / "subdir"

_create_local_module(model_dir_a, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, 'MAGIC = "B"\n')
_create_local_module(model_dir_c, 'MAGIC = "A"\n')

cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")
cached_module_c = get_cached_module_file(str(model_dir_c), "custom_model.py")

cached_module_path_a = Path(cached_module_a)
assert cached_module_path_a.parent.parent.name == "subdir"
assert len(cached_module_path_a.parent.name) == 16
assert cached_module_a != cached_module_b
assert cached_module_a == cached_module_c


def test_get_cached_module_file_local_cache_key_includes_relative_import_sources(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))

model_dir_a = tmp_path / "pretrained_a" / "subdir"
model_dir_b = tmp_path / "pretrained_b" / "subdir"

module_code = "from .helper import MAGIC\nVALUE = MAGIC\n"
_create_local_module(model_dir_a, module_code, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, module_code, 'MAGIC = "B"\n')

cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")

cached_helper_a = modules_cache / Path(cached_module_a).parent / "helper.py"
cached_helper_b = modules_cache / Path(cached_module_b).parent / "helper.py"

assert cached_module_a != cached_module_b
assert cached_helper_a.read_text(encoding="utf-8") == 'MAGIC = "A"\n'
assert cached_helper_b.read_text(encoding="utf-8") == 'MAGIC = "B"\n'


def test_get_cached_module_file_local_cache_key_keeps_hash_stable_with_different_basenames(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))

model_dir_a = tmp_path / "pretrained_a" / "alpha_subdir"
model_dir_b = tmp_path / "pretrained_b" / "beta_subdir"

_create_local_module(model_dir_a, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, 'MAGIC = "A"\n')

cached_module_a = Path(get_cached_module_file(str(model_dir_a), "custom_model.py"))
cached_module_b = Path(get_cached_module_file(str(model_dir_b), "custom_model.py"))

assert cached_module_a.parent.parent.name == "alpha_subdir"
assert cached_module_b.parent.parent.name == "beta_subdir"
assert cached_module_a.parent.name == cached_module_b.parent.name
Loading