From f8d04f8eac930439bfff7fe78944681181dd2c6e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 31 Mar 2023 17:27:44 +0530 Subject: [PATCH 1/9] add: first draft for a better LoRA enabler. --- src/diffusers/loaders.py | 442 +++++++++++++++++- .../pipeline_stable_diffusion.py | 7 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 1 + 4 files changed, 439 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 265ea92625f5..d5dd5b45d9be 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -21,6 +21,7 @@ from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, + TEXT_ENCODER_TARGET_MODULES, _get_model_file, deprecate, is_safetensors_available, @@ -81,12 +82,12 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict r""" Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. - This function is experimental and might change in the future. + This function is experimental and might change in the future. @@ -125,7 +126,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. - mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. @@ -133,8 +133,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). """ @@ -250,7 +250,7 @@ def save_attn_procs( ): r""" Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Arguments: save_directory (`str` or `os.PathLike`): @@ -288,7 +288,7 @@ def save_function(weights, filename): model_to_save = AttnProcsLayers(self.attn_processors) # Save the model - state_dict = model_to_save.state_dict() + state_dict = {"unet": model_to_save.state_dict()} if weight_name is None: if safe_serialization: @@ -372,12 +372,12 @@ def load_textual_inversion( - This function is experimental and might change in the future. + This function is experimental and might change in the future. Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): + pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. @@ -566,4 +566,426 @@ def load_textual_inversion( for token_id, embedding in zip(token_ids, embeddings): self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - logger.info("Loaded textual inversion embedding for {token}.") + logger.info(f"Loaded textual inversion embedding for {token}.") + + +class LoraLoaderMixin: + r""" + Utility class for handling the loading LoRA layers into UNet (of class [`~UNet2DConditionModel`]) and Text Encoder + (of class [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + """ + text_encoder_name = "text_encoder" + unet_name = "unet" + + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers (such as LoRA) into [`~UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # Load the layers corresponding to UNet. + if state_dict.get(self.unet_name, None) is not None: + logger.info(f"Loading {self.unet_name}.") + self.unet.load_attn_procs(state_dict[self.unet_name]) + + # Load the layers corresponding to text encoder and make necessary adjustments. + if state_dict.get(self.text_encoder_name, None) is not None: + logger.info(f"Loading {self.text_encoder_name}.") + attn_procs_text_encoder = self.load_attn_procs(state_dict[self.text_encoder_name], **kwargs) + self._modify_text_encoder(attn_procs_text_encoder) + + def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. + + Parameters: + attn_processors: Dict[str, `LoRAAttnProcessor`]: + A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. + """ + # Loop over the original attention modules. + for name, _ in self.text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + # Retrieve the module and its corresponding LoRA processor. + module = self.text_encoder.get_submodule(name) + # Construct a new function that performs the LoRA merging. We will monkey patch + # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward + + def new_forward(x): + return old_forward(x) + lora_layer(x) + + # Monkey-patch. + module.forward = new_forward + + def _get_lora_layer_attribute(self, name: str) -> str: + if "q_proj" in name: + return "to_q_lora" + elif "v_proj" in name: + return "to_v_lora" + elif "k_proj" in name: + return "to_k_lora" + else: + return "to_out_lora" + + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers for + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + + # set correct dtype & device + attn_processors = { + k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() + } + return attn_processors + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, torch.nn.Module] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + **kwargs, + ): + r""" + Save the LoRA parameters corresponding to the UNet and the text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the + serialization process easier and cleaner. + text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from + `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state + dict. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + weight_name = weight_name or deprecate( + "weights_name", + "0.18.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + state_dict = {} + if unet_lora_layers is not None: + state_dict.update({self.unet_name: unet_lora_layers.state_dict()}) + if text_encoder_lora_layers is not None: + state_dict.update({self.text_encoder_name: text_encoder_lora_layers.state_dict()}) + + # Save the model + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + # Save the model + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 73b9178e3ab1..4d060159209b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -53,7 +53,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -109,6 +109,9 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + self.text_encoder_name = text_encoder.config["architectures"].lower() + self.unet_name = unet.config["_class_name"].lower() + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3a1103ac1adf..bb159d9db375 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) from .deprecation_utils import deprecate diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index b9e60a2a873b..1134ba6fb656 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,3 +30,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] From e81913865affa661e243ddba5eb8a96ddbb46e81 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 31 Mar 2023 17:35:35 +0530 Subject: [PATCH 2/9] make fix-copies. --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 3 +++ .../pipeline_stable_diffusion_inpaint_legacy.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index c5bb8f9ac7b1..e63485180bf9 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,6 +106,9 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + self.text_encoder_name = text_encoder.config["architectures"].lower() + self.unet_name = unet.config["_class_name"].lower() + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index b7a0c942bbe2..74d2c6ec1c72 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -139,6 +139,9 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + self.text_encoder_name = text_encoder.config["architectures"].lower() + self.unet_name = unet.config["_class_name"].lower() + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." From a1670a01972a3efb9dc7a7514825c981e3acd63a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 3 Apr 2023 09:18:13 +0530 Subject: [PATCH 3/9] feat: backward compatibility. --- src/diffusers/loaders.py | 45 +++++++++++-------- .../alt_diffusion/pipeline_alt_diffusion.py | 3 -- .../pipeline_stable_diffusion.py | 3 -- ...ipeline_stable_diffusion_inpaint_legacy.py | 3 -- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d5dd5b45d9be..957b3e2aa646 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -288,7 +288,7 @@ def save_function(weights, filename): model_to_save = AttnProcsLayers(self.attn_processors) # Save the model - state_dict = {"unet": model_to_save.state_dict()} + state_dict = model_to_save.state_dict() if weight_name is None: if safe_serialization: @@ -714,16 +714,32 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di else: state_dict = pretrained_model_name_or_path_or_dict - # Load the layers corresponding to UNet. - if state_dict.get(self.unet_name, None) is not None: - logger.info(f"Loading {self.unet_name}.") - self.unet.load_attn_procs(state_dict[self.unet_name]) - - # Load the layers corresponding to text encoder and make necessary adjustments. - if state_dict.get(self.text_encoder_name, None) is not None: - logger.info(f"Loading {self.text_encoder_name}.") - attn_procs_text_encoder = self.load_attn_procs(state_dict[self.text_encoder_name], **kwargs) - self._modify_text_encoder(attn_procs_text_encoder) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` should either have + # (1) two keys at the root-level: `unet` and `text_encoder`. + # (2) OR, one of the two keys: `unet` or `text_encoder`. + if len(list(state_dict.keys())) in [1, 2]: + # Load the layers corresponding to UNet. + if state_dict.get(self.unet_name, None) is not None: + logger.info(f"Loading {self.unet_name}.") + self.unet.load_attn_procs(state_dict[self.unet_name]) + + # Load the layers corresponding to text encoder and make necessary adjustments. + if state_dict.get(self.text_encoder_name, None) is not None: + logger.info(f"Loading {self.text_encoder_name}.") + attn_procs_text_encoder = self.load_attn_procs(state_dict[self.text_encoder_name]) + self._modify_text_encoder(attn_procs_text_encoder) + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any high-level keys like + # `unet`. + else: + self.unet.load_attn_procs(state_dict) + logger.warning( + "You have saved the LoRA weights using the old format. This will be" + " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" + " in a dictionary and then create a new dictionary like the following:" + " `{new_dictionary.update('unet': old_dictionary)}`." + ) def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): r""" @@ -928,7 +944,6 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = False, - **kwargs, ): r""" Save the LoRA parameters corresponding to the UNet and the text encoder. @@ -952,12 +967,6 @@ def save_lora_weights( need to replace `torch.save` by another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. """ - weight_name = weight_name or deprecate( - "weights_name", - "0.18.0", - "`weights_name` is deprecated, please use `weight_name` instead.", - take_from=kwargs, - ) if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index e63485180bf9..c5bb8f9ac7b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,9 +106,6 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - self.text_encoder_name = text_encoder.config["architectures"].lower() - self.unet_name = unet.config["_class_name"].lower() - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4d060159209b..56810e8759b5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -109,9 +109,6 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - self.text_encoder_name = text_encoder.config["architectures"].lower() - self.unet_name = unet.config["_class_name"].lower() - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 74d2c6ec1c72..b7a0c942bbe2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -139,9 +139,6 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - self.text_encoder_name = text_encoder.config["architectures"].lower() - self.unet_name = unet.config["_class_name"].lower() - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." From 8f794b25cfc09bfa516094317cfd5e3a3f343ce1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 3 Apr 2023 09:20:55 +0530 Subject: [PATCH 4/9] add: entry to the docs. --- docs/source/en/api/loaders.mdx | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx index 1d55bd03c064..8cbf21b8e0cf 100644 --- a/docs/source/en/api/loaders.mdx +++ b/docs/source/en/api/loaders.mdx @@ -28,3 +28,11 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g ### UNet2DConditionLoadersMixin [[autodoc]] loaders.UNet2DConditionLoadersMixin + +### TextualInversionLoaderMixin + +[[autodoc]] loaders.TextualInversionLoaderMixin + +### LoraLoaderMixin + +[[autodoc]] loaders.LoraLoaderMixin From bac46f68fb79c5dd3bb1e752c39da3a6762ba95e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 3 Apr 2023 11:46:54 +0530 Subject: [PATCH 5/9] add: tests. --- tests/test_lora_layers.py | 213 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 tests/test_lora_layers.py diff --git a/tests/test_lora_layers.py b/tests/test_lora_layers.py new file mode 100644 index 000000000000..9bfce3646b9f --- /dev/null +++ b/tests/test_lora_layers.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +import torch +import torch.nn as nn +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device + + +def create_unet_lora_layers(unet: nn.Module): + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_layers = AttnProcsLayers(lora_attn_procs) + return lora_attn_procs, unet_lora_layers + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + return text_encoder_lora_layers + + +class LoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "text_encoder_lora_layers": text_encoder_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_safetensors(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_load_legacy(self): + pipeline_components, lora_components = self.get_dummy_components() + unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + unet = sd_pipe.unet + unet.set_attn_processor(unet_lora_attn_procs) + unet.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) From 8e02ffde106d8afabcda4a67600d730a4ddd0257 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 3 Apr 2023 13:02:54 +0530 Subject: [PATCH 6/9] fix: docs. --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 957b3e2aa646..bcc67135d792 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -571,7 +571,7 @@ def load_textual_inversion( class LoraLoaderMixin: r""" - Utility class for handling the loading LoRA layers into UNet (of class [`~UNet2DConditionModel`]) and Text Encoder + Utility class for handling the loading LoRA layers into UNet (of class [`UNet2DConditionModel`]) and Text Encoder (of class [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). @@ -585,7 +585,7 @@ class LoraLoaderMixin: def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" - Load pretrained attention processor layers (such as LoRA) into [`~UNet2DConditionModel`] and + Load pretrained attention processor layers (such as LoRA) into [`UNet2DConditionModel`] and [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). From f91f6bd1ef954eef1f8fadea995b42e3db1a39a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 4 Apr 2023 09:06:38 +0530 Subject: [PATCH 7/9] fix: norm group test for UNet3D. --- tests/models/test_models_unet_3d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 729367a0c164..5a0d74a3ea5a 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -119,12 +119,11 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - # Overriding because `block_out_channels` needs to be different for this model. + # Overriding to set `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 - init_dict["block_out_channels"] = (32, 64, 64, 64) model = self.model_class(**init_dict) model.to(torch_device) From cc375ad2daf29a156c090b016d808c94bdbbe75d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 09:09:44 +0530 Subject: [PATCH 8/9] feat: add support for flat dicts. --- src/diffusers/loaders.py | 60 +++++++++++++++++++++++++-------------- tests/test_lora_layers.py | 2 +- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bcc67135d792..4782e54c861e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -715,30 +715,36 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di state_dict = pretrained_model_name_or_path_or_dict # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` should either have - # (1) two keys at the root-level: `unet` and `text_encoder`. - # (2) OR, one of the two keys: `unet` or `text_encoder`. - if len(list(state_dict.keys())) in [1, 2]: - # Load the layers corresponding to UNet. - if state_dict.get(self.unet_name, None) is not None: - logger.info(f"Loading {self.unet_name}.") - self.unet.load_attn_procs(state_dict[self.unet_name]) - - # Load the layers corresponding to text encoder and make necessary adjustments. - if state_dict.get(self.text_encoder_name, None) is not None: - logger.info(f"Loading {self.text_encoder_name}.") - attn_procs_text_encoder = self.load_attn_procs(state_dict[self.text_encoder_name]) - self._modify_text_encoder(attn_procs_text_encoder) + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + + # Load the layers corresponding to UNet. + if all(key.startswith(self.unet_name) for key in keys): + logger.info(f"Loading {self.unet_name}.") + unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)} + self.unet.load_attn_procs(unet_lora_state_dict) + + # Load the layers corresponding to text encoder and make necessary adjustments. + elif all(key.startswith(self.text_encoder_name) for key in keys): + logger.info(f"Loading {self.text_encoder_name}.") + text_encoder_lora_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) + } + attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict) + self._modify_text_encoder(attn_procs_text_encoder) + # Otherwise, we're dealing with the old format. This means the `state_dict` should only - # contain the module names of the `unet` as its keys WITHOUT any high-level keys like - # `unet`. - else: + # contain the module names of the `unet` as its keys WITHOUT any prefix. + elif not all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ): self.unet.load_attn_procs(state_dict) logger.warning( "You have saved the LoRA weights using the old format. This will be" " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" " in a dictionary and then create a new dictionary like the following:" - " `{new_dictionary.update('unet': old_dictionary)}`." + " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." ) def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): @@ -826,6 +832,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. + Returns: + `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding + [`LoRAAttnProcessor`]. + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated @@ -982,11 +992,20 @@ def save_function(weights, filename): os.makedirs(save_directory, exist_ok=True) + # Create a flat dictionary. state_dict = {} if unet_lora_layers is not None: - state_dict.update({self.unet_name: unet_lora_layers.state_dict()}) + unet_lora_state_dict = { + f"{self.unet_name}.{module_name}": param + for module_name, param in unet_lora_layers.state_dict().items() + } + state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: - state_dict.update({self.text_encoder_name: text_encoder_lora_layers.state_dict()}) + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param + for module_name, param in text_encoder_lora_layers.state_dict().items() + } + state_dict.update(text_encoder_lora_state_dict) # Save the model if weight_name is None: @@ -995,6 +1014,5 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME - # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") diff --git a/tests/test_lora_layers.py b/tests/test_lora_layers.py index 9bfce3646b9f..9bcdc5d93301 100644 --- a/tests/test_lora_layers.py +++ b/tests/test_lora_layers.py @@ -187,7 +187,7 @@ def test_lora_save_load_safetensors(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - def test_lora_load_legacy(self): + def test_lora_save_load_legacy(self): pipeline_components, lora_components = self.get_dummy_components() unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] sd_pipe = StableDiffusionPipeline(**pipeline_components) From 0a974969ebe2809c4668a6b50be2b97f74131617 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 6 Apr 2023 20:19:55 +0530 Subject: [PATCH 9/9] add depcrcation message instead of warning. --- src/diffusers/loaders.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9386bdf956f2..31939ca4b481 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -740,12 +740,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ): self.unet.load_attn_procs(state_dict) - logger.warning( - "You have saved the LoRA weights using the old format. This will be" - " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" - " in a dictionary and then create a new dictionary like the following:" - " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." - ) + deprecation_message = "You have saved the LoRA weights using the old format. This will be" + " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" + " in a dictionary and then create a new dictionary like the following:" + " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." + deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): r"""