diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 34bba89e2e07..721f13a2f314 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
self.to_json_file(output_config_file)
logger.info(f"ConfigMixinuration saved in {output_config_file}")
+
@classmethod
def get_config_dict(
@@ -182,35 +183,42 @@ def get_config_dict(
logger.info(f"loading configuration file {config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
+
+        return config_dict
+
+    @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
-
+ init_dict = {}
for key in expected_keys:
if key in kwargs:
# overwrite key
- config_dict[key] = kwargs.pop(key)
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
- passed_keys = set(config_dict.keys())
-
- unused_kwargs = kwargs
- for key in passed_keys - expected_keys:
- unused_kwargs[key] = config_dict.pop(key)
+        unused_kwargs = {**config_dict, **kwargs}
+
+ passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
            logger.warning(
                f"{expected_keys - passed_keys} were not found in config. Values will be initialized to default values."
)
- return config_dict, unused_kwargs
+        return init_dict, unused_kwargs
+
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
- config_dict, unused_kwargs = cls.get_config_dict(
+ config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
- model = cls(**config_dict)
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+ model = cls(**init_dict)
if return_unused_kwargs:
return model, unused_kwargs
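
Net effect of the `ConfigMixin` split: `get_config_dict` now returns only the raw config dict, and `extract_init_dict` filters it against the `__init__` signature. A minimal sketch of the new contract (the `SampleScheduler` class and its arguments are invented for illustration):

```python
from diffusers.configuration_utils import ConfigMixin


class SampleScheduler(ConfigMixin):
    def __init__(self, timesteps=1000, beta_start=0.0001):
        self.timesteps = timesteps
        self.beta_start = beta_start


# A loaded config may carry keys the constructor does not accept.
config_dict = {"timesteps": 10, "beta_start": 0.0002, "legacy_key": 42}

# extract_init_dict pops from the dict it is given, so pass a copy.
init_dict, unused_kwargs = SampleScheduler.extract_init_dict(dict(config_dict))

assert init_dict == {"timesteps": 10, "beta_start": 0.0002}  # routed to __init__
assert unused_kwargs == {"legacy_key": 42}  # surplus keys are returned, not fatal
```
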
diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py
new file mode 100644
index 000000000000..a433c2090ab9
--- /dev/null
+++ b/src/diffusers/dynamic_modules_utils.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from transformers.utils import (
+    HF_MODULES_CACHE,
+    TRANSFORMERS_DYNAMIC_MODULE_NAME,
+    cached_path,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from .xxx import yyy`
+    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from xxx import yyy`
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
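
For reference, a standalone sketch of how the two import scanners above behave on a small file (not part of the patch; the module contents are invented):

```python
import os
import tempfile

from diffusers.dynamic_modules_utils import check_imports, get_relative_imports

source = "import torch\nfrom .scheduling_utils import SchedulerMixin\nfrom numpy import array\n"

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "modeling_example.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(source)

    # Only the leading-dot import is reported as relative.
    assert get_relative_imports(path) == ["scheduling_utils"]

    # check_imports() additionally verifies that `torch` and `numpy` are
    # installed (raising ImportError otherwise) before returning the same list.
    assert check_imports(path) == ["scheduling_utils"]
```
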
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
+ Transformers module.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the module from local files.
+
+    <Tip>
+
+    Passing `use_auth_token=True` is required when you want to use a private model.
+
+    </Tip>
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+ submodule = "local"
+
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = cached_path(
+ module_file_or_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+    # We always copy local files (we could hash the file to see if there was a change, and give it the name of
+    # that hash, to only copy when there is a modification, but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+    <Tip warning={true}>
+
+    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+    therefore only be called on trusted repos.
+
+    </Tip>
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the module from local files.
+
+    <Tip>
+
+    Passing `use_auth_token=True` is required when you want to use a private model.
+
+    </Tip>
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
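
End to end, the helpers above combine as follows (a sketch; the local folder `./my-pipeline` and its `DDPM` class are hypothetical):

```python
from diffusers.dynamic_modules_utils import get_cached_module_file, get_class_in_module

# Copies ./my-pipeline/modeling_ddpm.py into the dynamic-module cache
# (HF_MODULES_CACHE/transformers_modules/local/) after checking its imports,
# and returns the cache-relative path "transformers_modules/local/modeling_ddpm.py".
cached_file = get_cached_module_file("./my-pipeline", "modeling_ddpm.py")

# Imports the cached module and pulls the class object out of it.
DDPM = get_class_in_module("DDPM", cached_file.replace(".py", ""))
```
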
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 2749ad68b7a0..1d9a2fd98971 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -16,6 +16,7 @@
import importlib
import os
+from pathlib import Path
from typing import Optional, Union
from huggingface_hub import snapshot_download
@@ -23,6 +24,7 @@
from transformers.utils import logging
from .configuration_utils import ConfigMixin
+from .dynamic_modules_utils import get_class_from_dynamic_module
INDEX_FILE = "diffusion_model.pt"
@@ -54,19 +56,23 @@ def register_modules(self, **kwargs):
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
- register_dict["_module"] = self.__module__
+
# save model index config
self.register(**register_dict)
# set models
setattr(self, name, module)
+
+        register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
+ self.register(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = self._dict_to_save
model_index_dict.pop("_class_name")
+ model_index_dict.pop("_module")
for name, (library_name, class_name) in self._dict_to_save.items():
importable_classes = LOADABLE_CLASSES[library_name]
@@ -95,16 +101,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
cached_folder = pretrained_model_name_or_path
- config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
+ config_dict = cls.get_config_dict(cached_folder)
+
+ module = config_dict["_module"]
+ class_name_ = config_dict["_class_name"]
+
+ if class_name_ == cls.__name__:
+ pipeline_class = cls
+ else:
+            pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cache_dir=cached_folder)
+
- module = pipeline_kwargs.pop("_module", None)
- # TODO(Suraj) - make from hub import work
- # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
- # Add Sylvains code from transformers
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
- for name, (library_name, class_name) in config_dict.items():
+ for name, (library_name, class_name) in init_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module:
@@ -129,5 +141,5 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
- model = cls(**init_kwargs)
+ model = pipeline_class(**init_kwargs)
return model
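
For context, after `save_pretrained` the pipeline's `model_index.json` records both the pipeline class name and the module file it lives in. A sketch of the resulting loading flow (the index contents shown are illustrative):

```python
# model_index.json, roughly:
# {
#     "_class_name": "DDPM",
#     "_module": "modeling_ddpm.py",
#     "unet": ["diffusers", "UNetModel"],
#     "noise_scheduler": ["diffusers", "GaussianDDPMScheduler"]
# }
from diffusers.pipeline_utils import DiffusionPipeline

# Because "DDPM" != "DiffusionPipeline", from_pretrained resolves the DDPM class
# from modeling_ddpm.py in the downloaded snapshot via get_class_from_dynamic_module,
# then instantiates it with the sub-models listed in the index.
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-cifar10")
```
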
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 6dce91ae4b3a..a1c6079ce24e 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -22,6 +22,8 @@
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
+from diffusers.pipeline_utils import DiffusionPipeline
+from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random()
@@ -199,3 +201,46 @@ def test_sample_fast(self):
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
+
+
+class PipelineTesterMixin(unittest.TestCase):
+ def test_from_pretrained_save_pretrained(self):
+ # 1. Load models
+ model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
+        scheduler = GaussianDDPMScheduler(timesteps=10)
+
+        ddpm = DDPM(model, scheduler)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ ddpm.save_pretrained(tmpdirname)
+ new_ddpm = DDPM.from_pretrained(tmpdirname)
+
+ generator = torch.Generator()
+ generator = generator.manual_seed(669472945848556)
+
+ image = ddpm(generator=generator)
+ generator = generator.manual_seed(669472945848556)
+ new_image = new_ddpm(generator=generator)
+
+ assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+
+ @slow
+ def test_from_pretrained_hub(self):
+ model_path = "fusing/ddpm-cifar10"
+
+ ddpm = DDPM.from_pretrained(model_path)
+ ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
+
+ ddpm.noise_scheduler.num_timesteps = 10
+ ddpm_from_hub.noise_scheduler.num_timesteps = 10
+
+ generator = torch.Generator(device=torch_device)
+ generator = generator.manual_seed(669472945848556)
+
+ image = ddpm(generator=generator)
+ generator = generator.manual_seed(669472945848556)
+ new_image = ddpm_from_hub(generator=generator)
+
+ assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"