diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 9c9e7b929f6f..184e13a41c44 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -311,6 +311,42 @@ def get_class_in_module( return getattr(module, class_name) +def _compute_local_source_files_hash( + pretrained_model_name_or_path: str | os.PathLike, + module_file: str | os.PathLike, + resolved_module_file: str | os.PathLike, + modules_needed: list[str], +) -> str: + """ + Computes a stable hash from the bytes of the local source file and its direct relative-import source files. + """ + model_path = Path(pretrained_model_name_or_path).resolve() + module_parent = Path(module_file).parent + + resolved_module_file = Path(resolved_module_file).resolve() + + def _resolve_relative_source_path(source_file_path: Path) -> str: + try: + return source_file_path.relative_to(model_path).as_posix() + except ValueError: + # Fallback for edge cases where the source file is not under the local model directory. + return source_file_path.as_posix() + + files_to_hash = [ + (_resolve_relative_source_path(resolved_module_file), resolved_module_file), + ] + for module_needed in modules_needed: + module_needed_path = (model_path / module_parent / f"{module_needed}.py").resolve() + files_to_hash.append((_resolve_relative_source_path(module_needed_path), module_needed_path)) + + source_files_hash = hashlib.sha256() + for relative_path, file_path in sorted(files_to_hash, key=lambda entry: entry[0]): + source_files_hash.update(relative_path.encode("utf-8")) + source_files_hash.update(file_path.read_bytes()) + + return source_files_hash.hexdigest()[:16] + + def get_cached_module_file( pretrained_model_name_or_path: str | os.PathLike, module_file: str, @@ -376,9 +412,8 @@ def get_cached_module_file( # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)) - else: + cached_module = None + if not is_local: submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/"))) cached_module = try_to_load_from_cache( pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type @@ -408,12 +443,21 @@ def get_cached_module_file( # Check we have all the requirements in our environment modules_needed = check_imports(resolved_module_file) + if is_local: + local_model_name = _sanitize_module_name(os.path.basename(os.path.normpath(pretrained_model_name_or_path))) + local_source_files_hash = _compute_local_source_files_hash( + pretrained_model_name_or_path, module_file, resolved_module_file, modules_needed + ) + if local_model_name: + submodule = os.path.sep.join([local_model_name, local_source_files_hash]) + else: + submodule = local_source_files_hash # Now we move the module inside our cached dynamic modules. full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule create_dynamic_module(full_submodule) submodule_path = Path(HF_MODULES_CACHE) / full_submodule - if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)): + if is_local: # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or # has changed since last copy. if not (submodule_path / module_file).exists() or not filecmp.cmp( diff --git a/tests/utils/test_dynamic_module_utils.py b/tests/utils/test_dynamic_module_utils.py index dfdc63460cd3..50b2ae2c8b17 100644 --- a/tests/utils/test_dynamic_module_utils.py +++ b/tests/utils/test_dynamic_module_utils.py @@ -13,10 +13,12 @@ # limitations under the License. import os +from pathlib import Path import pytest -from transformers.dynamic_module_utils import get_imports +from transformers import dynamic_module_utils +from transformers.dynamic_module_utils import get_cached_module_file, get_imports TOP_LEVEL_IMPORT = """ @@ -127,3 +129,73 @@ def test_import_parsing(tmp_path, case): parsed_imports = get_imports(tmp_file_path) assert parsed_imports == ["os"] + + +def _create_local_module(module_dir: Path, module_code: str, helper_code: str | None = None): + module_dir.mkdir(parents=True, exist_ok=True) + (module_dir / "custom_model.py").write_text(module_code, encoding="utf-8") + if helper_code is not None: + (module_dir / "helper.py").write_text(helper_code, encoding="utf-8") + + +def test_get_cached_module_file_local_cache_key_uses_basename_and_content_hash(monkeypatch, tmp_path): + modules_cache = tmp_path / "hf_modules_cache" + monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache)) + + model_dir_a = tmp_path / "pretrained_a" / "subdir" + model_dir_b = tmp_path / "pretrained_b" / "subdir" + model_dir_c = tmp_path / "pretrained_c" / "subdir" + + _create_local_module(model_dir_a, 'MAGIC = "A"\n') + _create_local_module(model_dir_b, 'MAGIC = "B"\n') + _create_local_module(model_dir_c, 'MAGIC = "A"\n') + + cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py") + cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py") + cached_module_c = get_cached_module_file(str(model_dir_c), "custom_model.py") + + cached_module_path_a = Path(cached_module_a) + assert cached_module_path_a.parent.parent.name == "subdir" + assert len(cached_module_path_a.parent.name) == 16 + assert cached_module_a != cached_module_b + assert cached_module_a == cached_module_c + + +def test_get_cached_module_file_local_cache_key_includes_relative_import_sources(monkeypatch, tmp_path): + modules_cache = tmp_path / "hf_modules_cache" + monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache)) + + model_dir_a = tmp_path / "pretrained_a" / "subdir" + model_dir_b = tmp_path / "pretrained_b" / "subdir" + + module_code = "from .helper import MAGIC\nVALUE = MAGIC\n" + _create_local_module(model_dir_a, module_code, 'MAGIC = "A"\n') + _create_local_module(model_dir_b, module_code, 'MAGIC = "B"\n') + + cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py") + cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py") + + cached_helper_a = modules_cache / Path(cached_module_a).parent / "helper.py" + cached_helper_b = modules_cache / Path(cached_module_b).parent / "helper.py" + + assert cached_module_a != cached_module_b + assert cached_helper_a.read_text(encoding="utf-8") == 'MAGIC = "A"\n' + assert cached_helper_b.read_text(encoding="utf-8") == 'MAGIC = "B"\n' + + +def test_get_cached_module_file_local_cache_key_keeps_hash_stable_with_different_basenames(monkeypatch, tmp_path): + modules_cache = tmp_path / "hf_modules_cache" + monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache)) + + model_dir_a = tmp_path / "pretrained_a" / "alpha_subdir" + model_dir_b = tmp_path / "pretrained_b" / "beta_subdir" + + _create_local_module(model_dir_a, 'MAGIC = "A"\n') + _create_local_module(model_dir_b, 'MAGIC = "A"\n') + + cached_module_a = Path(get_cached_module_file(str(model_dir_a), "custom_model.py")) + cached_module_b = Path(get_cached_module_file(str(model_dir_b), "custom_model.py")) + + assert cached_module_a.parent.parent.name == "alpha_subdir" + assert cached_module_b.parent.parent.name == "beta_subdir" + assert cached_module_a.parent.name == cached_module_b.parent.name