From c810d4827f535208e13ad3d8c20119d5cfdebc61 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 18:10:25 +0530 Subject: [PATCH 01/11] better support offloading when side loading is enabled. --- src/diffusers/loaders.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 52970e48147d..1d39a2ff634e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -45,6 +45,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights + from accelerate.hooks import CpuOffload, remove_hook_from_module from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) @@ -763,6 +764,16 @@ def load_textual_inversion( f" `{self.load_textual_inversion.__name__}`" ) + # Remove any existing hooks. + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -916,6 +927,12 @@ def load_textual_inversion( for token_id, embedding in token_ids_and_embeddings: self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + # offload back + if is_model_cpu_offload: + self.enable_model_cpu_offload() + else: + self.enable_sequential_cpu_offload() + class LoraLoaderMixin: r""" @@ -946,6 +963,16 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ + # Remove any existing hooks. + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) self.load_lora_into_text_encoder( @@ -955,6 +982,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + else: + self.enable_sequential_cpu_offload() + @classmethod def lora_state_dict( cls, From c14fc20e84ed28aca2be0466eea51d8eebe4592b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 18:12:37 +0530 Subject: [PATCH 02/11] load_textual_inversion --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 1d39a2ff634e..ad6b76157475 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -770,7 +770,7 @@ def load_textual_inversion( if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) remove_hook_from_module(component) From 46b0874024d0b4f0c394d533a785e30092a394dd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 18:13:31 +0530 Subject: [PATCH 03/11] better messaging for textual inversion. --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ad6b76157475..15d9ceef7261 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -770,7 +770,7 @@ def load_textual_inversion( if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) logger.info( - "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." ) remove_hook_from_module(component) From 6c842c72f00831e6bbb19183c91eaf0533ec8da8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 19:00:22 +0530 Subject: [PATCH 04/11] fixes --- src/diffusers/loaders.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 15d9ceef7261..89ce480cf445 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -45,7 +45,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from accelerate.hooks import CpuOffload, remove_hook_from_module + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) @@ -765,10 +765,13 @@ def load_textual_inversion( ) # Remove any existing hooks. + is_model_cpu_offload = False + is_sequential_cpu_offload = False for _, component in self.components.items(): if isinstance(component, nn.Module): if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) logger.info( "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." ) @@ -930,7 +933,7 @@ def load_textual_inversion( # offload back if is_model_cpu_offload: self.enable_model_cpu_offload() - else: + elif is_sequential_cpu_offload: self.enable_sequential_cpu_offload() @@ -964,10 +967,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ # Remove any existing hooks. + is_model_cpu_offload = False + is_sequential_cpu_offload = False for _, component in self.components.items(): if isinstance(component, nn.Module): if hasattr(component, "_hf_hook"): is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) @@ -985,7 +991,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Offload back. if is_model_cpu_offload: self.enable_model_cpu_offload() - else: + elif is_sequential_cpu_offload: self.enable_sequential_cpu_offload() @classmethod From 2a27542b501a59c380036c79aa354893fd3390eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 19:01:54 +0530 Subject: [PATCH 05/11] address PR feedback. --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 89ce480cf445..9128e442617c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -765,7 +765,7 @@ def load_textual_inversion( ) # Remove any existing hooks. - is_model_cpu_offload = False + is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): if isinstance(component, nn.Module): @@ -967,7 +967,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ # Remove any existing hooks. - is_model_cpu_offload = False + is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): if isinstance(component, nn.Module): From b3fb9a7d3245ee192f58b9f9edd90775980850f4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 20:02:12 +0530 Subject: [PATCH 06/11] sdxl support. --- .../controlnet/pipeline_controlnet_sd_xl.py | 23 +++++++++++++++++++ .../pipeline_stable_diffusion_xl.py | 23 +++++++++++++++++++ .../pipeline_stable_diffusion_xl_img2img.py | 23 +++++++++++++++++++ .../pipeline_stable_diffusion_xl_inpaint.py | 23 +++++++++++++++++++ 4 files changed, 92 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index cd2f3daf55b4..c3c12f66d5f8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1221,6 +1221,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False + is_sequential_cpu_offload = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1248,6 +1265,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2d4ef87bdf79..98377cb75f94 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -907,6 +907,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False + is_sequential_cpu_offload = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -934,6 +951,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod def save_lora_weights( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index aada52eecbcb..0efc981fefcb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1065,6 +1065,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False + is_sequential_cpu_offload = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1092,6 +1109,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index bc29cecdbdc9..669d2653e867 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1370,6 +1370,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False + is_sequential_cpu_offload = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1397,6 +1414,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( From 773ff9120f2033a3f58fbf34f3acbf628267df06 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 20:06:01 +0530 Subject: [PATCH 07/11] improve messaging --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c3c12f66d5f8..0b9fc3a3754c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1226,7 +1226,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 98377cb75f94..ae5cc1ac58b1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -912,7 +912,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 0efc981fefcb..4532c80dfb44 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1070,7 +1070,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 669d2653e867..80c96459146c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1375,7 +1375,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") is_model_cpu_offload = False is_sequential_cpu_offload = False for _, component in self.components.items(): From b8b542278e4f8bead2dbd6dc496f7fd79ae1983a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 14:57:33 +0530 Subject: [PATCH 08/11] recursive removal when cpu sequential offloading is enabled. --- src/diffusers/loaders.py | 4 ++- .../pipeline_controlnet_inpaint_sd_xl.py | 26 +++++++++++++++++++ .../controlnet/pipeline_controlnet_sd_xl.py | 5 +++- .../pipeline_stable_diffusion_xl.py | 5 +++- .../pipeline_stable_diffusion_xl_img2img.py | 5 +++- .../pipeline_stable_diffusion_xl_inpaint.py | 5 +++- 6 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9128e442617c..e1656a0ced0d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -969,6 +969,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Remove any existing hooks. is_model_cpu_offload = False is_sequential_cpu_offload = False + recurive = False for _, component in self.components.items(): if isinstance(component, nn.Module): if hasattr(component, "_hf_hook"): @@ -977,7 +978,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recurive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recurive) state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 48c0f289d067..a59c0c8b7ef4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1544,6 +1544,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1571,6 +1591,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 3002562dcd3e..5b8dbdf20cb6 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1213,8 +1213,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False is_sequential_cpu_offload = False + recursive = False for _, component in self.components.items(): if isinstance(component, torch.nn.Module): if hasattr(component, "_hf_hook"): @@ -1223,7 +1225,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c153fd45b8f8..a3c7223ee2e9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -913,8 +913,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False is_sequential_cpu_offload = False + recursive = False for _, component in self.components.items(): if isinstance(component, torch.nn.Module): if hasattr(component, "_hf_hook"): @@ -923,7 +925,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index c7d0d532ff3a..1db17ab206e5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1071,8 +1071,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False is_sequential_cpu_offload = False + recursive = False for _, component in self.components.items(): if isinstance(component, torch.nn.Module): if hasattr(component, "_hf_hook"): @@ -1081,7 +1083,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 105cf67b1deb..7b70493e5eea 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1385,8 +1385,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + is_model_cpu_offload = False is_sequential_cpu_offload = False + recursive = False for _, component in self.components.items(): if isinstance(component, torch.nn.Module): if hasattr(component, "_hf_hook"): @@ -1395,7 +1397,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recursive=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, From 3d06c51af219a02d01b5bf2b883953b63ed8d645 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 15:09:23 +0530 Subject: [PATCH 09/11] add: lora tests --- tests/models/test_lora_layers.py | 56 +++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 848f2f44adc9..59175f9f9efd 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -926,6 +926,42 @@ def test_a1111(self): self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_a1111_with_model_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_a1111_with_sequential_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_kohya_sd_v15_with_higher_dimensions(self): generator = torch.Generator().manual_seed(0) @@ -1102,10 +1138,10 @@ def test_sdxl_1_0_lora(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() lora_model_id = "hf-internal-testing/sdxl-1.0-lora" lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 @@ -1258,3 +1294,21 @@ def test_sdxl_1_0_fuse_unfuse_all(self): assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd) assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd) assert state_dicts_almost_equal(unet_sd, new_unet_sd) + + def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) From 340887e7a521d679c9a0843552b21c922e01ba67 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 15:27:33 +0530 Subject: [PATCH 10/11] recruse. --- src/diffusers/loaders.py | 6 ++++-- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../pipeline_stable_diffusion_xl_img2img.py | 2 +- .../pipeline_stable_diffusion_xl_inpaint.py | 2 +- 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e1656a0ced0d..2a7a0f524471 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -767,6 +767,7 @@ def load_textual_inversion( # Remove any existing hooks. is_model_cpu_offload = False is_sequential_cpu_offload = False + recursive = False for _, component in self.components.items(): if isinstance(component, nn.Module): if hasattr(component, "_hf_hook"): @@ -775,7 +776,8 @@ def load_textual_inversion( logger.info( "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." ) - remove_hook_from_module(component) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -979,7 +981,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recurive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recurive) + remove_hook_from_module(component, recurse=recurive) state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index a59c0c8b7ef4..44fd6f480991 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1563,7 +1563,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recursive) + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 5b8dbdf20cb6..b6f86be5680b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1226,7 +1226,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recursive) + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a3c7223ee2e9..a403660dc07d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -926,7 +926,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recursive) + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 1db17ab206e5..3817cb7c53ec 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1084,7 +1084,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recursive) + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7b70493e5eea..1972037e0e03 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1398,7 +1398,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." ) recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recursive=recursive) + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, From 7bcf71d0bd4b8a5ed438c3543ffbb12aafea8536 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 15:48:14 +0530 Subject: [PATCH 11/11] add: offload tests for textual inversion. --- .../stable_diffusion/test_stable_diffusion.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 7935a63eceaa..31de557a0ac3 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1019,6 +1019,56 @@ def test_stable_diffusion_textual_inversion(self): max_diff = np.abs(expected_image - image).max() assert max_diff < 8e-1 + def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.enable_model_cpu_offload() + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) + pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) + + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 8e-1 + + def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.enable_sequential_cpu_offload() + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) + pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) + + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 8e-1 + @require_torch_2 def test_stable_diffusion_compile(self): seed = 0