From ce5666211e0c0aef498fa4459302e2402c130470 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 7 Jun 2022 13:56:09 +0200 Subject: [PATCH 1/4] make from hub import work --- models/vision/ddpm/modeling_ddpm.py | 2 +- src/diffusers/dynamic_modules_utils.py | 339 +++++++++++++++++++++++++ src/diffusers/pipeline_utils.py | 13 +- 3 files changed, 347 insertions(+), 7 deletions(-) create mode 100644 src/diffusers/dynamic_modules_utils.py diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 4a3f0b24b72e..ae049a8c0acf 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline): modeling_file = "modeling_ddpm.py" - def __init__(self, unet, noise_scheduler, vqvae): + def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py new file mode 100644 index 000000000000..a433c2090ab9 --- /dev/null +++ b/src/diffusers/dynamic_modules_utils.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to dynamically load objects from the Hub.""" + +import importlib +import os +import re +import shutil +import sys +from pathlib import Path +from typing import Dict, Optional, Union + +from huggingface_hub import HfFolder, model_info + +from transformers.utils import ( + HF_MODULES_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + cached_path, + hf_bucket_url, + is_offline_mode, + logging, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def create_dynamic_module(name: Union[str, os.PathLike]): + """ + Creates a dynamic module in the cache directory for modules. + """ + init_hf_modules() + dynamic_module_path = Path(HF_MODULES_CACHE) / name + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + + +def get_relative_imports(module_file): + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file): + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def check_imports(filename): + """ + Check if the current Python environment contains all the libraries that are imported in a file. + """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import xxx` + imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + + # Unique-ify and test we got them all + imports = list(set(imports)) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError: + missing_packages.append(imp) + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module(class_name, module_path): + """ + Import a module on the cache directory for modules and extract a class from it. + """ + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache. + """ + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + submodule = "local" + + try: + # Load from URL or cache if already cached + resolved_module_file = cached_path( + module_file_or_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + class_name: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") + ```""" + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 6b56f7823210..e1b53d9e3f2d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -16,6 +16,7 @@ import importlib import os +from pathlib import Path from typing import Optional, Union from huggingface_hub import snapshot_download @@ -23,6 +24,7 @@ from transformers.utils import logging from .configuration_utils import ConfigMixin +from .dynamic_modules_utils import get_class_from_dynamic_module INDEX_FILE = "diffusion_model.pt" @@ -91,12 +93,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): # use snapshot download here to get it working from from_pretrained cached_folder = snapshot_download(pretrained_model_name_or_path) - config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) + _, config_dict = cls.get_config_dict(cached_folder) - module = pipeline_kwargs["_module"] - # TODO(Suraj) - make from hub import work - # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work - # Add Sylvains code from transformers + module = config_dict.pop("_module", None) + class_name_ = config_dict.pop("_class_name") init_kwargs = {} @@ -122,5 +122,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - model = cls(**init_kwargs) + class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + model = class_obj(**init_kwargs) return model From fe99460b5fe9f6c6f5199b4ecc42ab48574e69ae Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 7 Jun 2022 14:26:20 +0200 Subject: [PATCH 2/4] update config dict logic --- src/diffusers/configuration_utils.py | 29 ++++++++++++++++++---------- src/diffusers/pipeline_utils.py | 15 +++++++------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 34bba89e2e07..ca61120ffad6 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool self.to_json_file(output_config_file) logger.info(f"ConfigMixinuration saved in {output_config_file}") + @classmethod def get_config_dict( @@ -182,35 +183,43 @@ def get_config_dict( logger.info(f"loading configuration file {config_file}") else: logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") + + return config_dict + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys.remove("self") - + import ipdb; ipdb.set_trace() + init_dict = {} for key in expected_keys: if key in kwargs: # overwrite key - config_dict[key] = kwargs.pop(key) + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) - passed_keys = set(config_dict.keys()) - - unused_kwargs = kwargs - for key in passed_keys - expected_keys: - unused_kwargs[key] = config_dict.pop(key) + unused_kwargs = config_dict.update(kwargs) + + passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: logger.warn( f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) - return config_dict, unused_kwargs + return init_dict, unused_kwargs @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): - config_dict, unused_kwargs = cls.get_config_dict( + config_dict = cls.get_config_dict( pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs ) - model = cls(**config_dict) + init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + + model = cls(**init_dict) if return_unused_kwargs: return model, unused_kwargs diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 8037d121f750..ba3a823e262e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -97,16 +97,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: cached_folder = pretrained_model_name_or_path - config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) + config_dict = cls.get_config_dict(cached_folder) + module = config_dict["_module"] + class_name_ = config_dict["_class_name"] + class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) - module = pipeline_kwargs.pop("_module", None) - # TODO(Suraj) - make from hub import work - # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work - # Add Sylvains code from transformers + init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs) + import ipdb; ipdb.set_trace() init_kwargs = {} - for name, (library_name, class_name) in config_dict.items(): + for name, (library_name, class_name) in init_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] if library_name == module: @@ -131,6 +132,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + model = class_obj(**init_kwargs) return model From d8287fcd1d94f33df55b54e2e1c140c2ab15b444 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 7 Jun 2022 15:39:47 +0200 Subject: [PATCH 3/4] fix issues with loading, add test for pipeline --- src/diffusers/configuration_utils.py | 1 - src/diffusers/pipeline_utils.py | 20 +++++++++---- tests/__init__.py | 0 tests/test_modeling_utils.py | 45 ++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 tests/__init__.py diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index ca61120ffad6..721f13a2f314 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -190,7 +190,6 @@ def get_config_dict( def extract_init_dict(cls, config_dict, **kwargs): expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys.remove("self") - import ipdb; ipdb.set_trace() init_dict = {} for key in expected_keys: if key in kwargs: diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index ba3a823e262e..1d9a2fd98971 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -56,19 +56,23 @@ def register_modules(self, **kwargs): class_name = module.__class__.__name__ register_dict = {name: (library, class_name)} - register_dict["_module"] = self.__module__ + # save model index config self.register(**register_dict) # set models setattr(self, name, module) + + register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"} + self.register(**register_dict) def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) model_index_dict = self._dict_to_save model_index_dict.pop("_class_name") + model_index_dict.pop("_module") for name, (library_name, class_name) in self._dict_to_save.items(): importable_classes = LOADABLE_CLASSES[library_name] @@ -98,12 +102,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P cached_folder = pretrained_model_name_or_path config_dict = cls.get_config_dict(cached_folder) + module = config_dict["_module"] class_name_ = config_dict["_class_name"] - class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + + if class_name_ == cls.__name__: + pipeline_class = cls + else: + pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + - init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs) - import ipdb; ipdb.set_trace() + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -132,6 +141,5 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - model = class_obj(**init_kwargs) + model = pipeline_class(**init_kwargs) return model diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6dce91ae4b3a..04e7ddc5ad09 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,6 +22,8 @@ import torch from diffusers import GaussianDDPMScheduler, UNetModel +from diffusers.pipeline_utils import DiffusionPipeline +from models.vision.ddpm.modeling_ddpm import DDPM global_rng = random.Random() @@ -199,3 +201,46 @@ def test_sample_fast(self): assert image.shape == (1, 3, 256, 256) image_slice = image[0, -1, -3:, -3:].cpu() assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3 + + +class PipelineTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): + # 1. Load models + model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) + schedular = GaussianDDPMScheduler(timesteps=10) + + ddpm = DDPM(model, schedular) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPM.from_pretrained(tmpdirname) + + generator = torch.Generator() + generator = generator.manual_seed(669472945848556) + + image = ddpm(generator) + generator = generator.manual_seed(669472945848556) + new_image = new_ddpm(generator) + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + + @slow + def test_from_pretrained_hub(self): + model_path = "fusing/ddpm-cifar10" + + ddpm = DDPM.from_pretrained(model_path) + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) + + ddpm.noise_scheduler.num_timesteps = 10 + ddpm_from_hub.noise_scheduler.num_timesteps = 10 + + + generator = torch.Generator(device=torch_device) + generator = generator.manual_seed(669472945848556) + + image = ddpm(generator) + generator = generator.manual_seed(669472945848556) + new_image = ddpm_from_hub(generator) + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" From 733546210e0007ea8066103ebc9f928b17dd747a Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 7 Jun 2022 15:43:08 +0200 Subject: [PATCH 4/4] fix tests --- tests/test_modeling_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 04e7ddc5ad09..a1c6079ce24e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -218,9 +218,9 @@ def test_from_pretrained_save_pretrained(self): generator = torch.Generator() generator = generator.manual_seed(669472945848556) - image = ddpm(generator) + image = ddpm(generator=generator) generator = generator.manual_seed(669472945848556) - new_image = new_ddpm(generator) + new_image = new_ddpm(generator=generator) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" @@ -239,8 +239,8 @@ def test_from_pretrained_hub(self): generator = torch.Generator(device=torch_device) generator = generator.manual_seed(669472945848556) - image = ddpm(generator) + image = ddpm(generator=generator) generator = generator.manual_seed(669472945848556) - new_image = ddpm_from_hub(generator) + new_image = ddpm_from_hub(generator=generator) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"