From a009f1d1fe03fe622b57de5e53cbe283257f91ec Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Sat, 25 Mar 2023 09:37:05 +0530
Subject: [PATCH 1/6] improve stable unclip doc.

---
 .../source/en/api/pipelines/stable_unclip.mdx | 58 +++++++++++++++----
 1 file changed, 48 insertions(+), 10 deletions(-)

diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx
index c8b5d58705ba..372242ae2dff 100644
--- a/docs/source/en/api/pipelines/stable_unclip.mdx
+++ b/docs/source/en/api/pipelines/stable_unclip.mdx
@@ -42,12 +42,9 @@ Coming soon!
 ### Text guided Image-to-Image Variation
 
 ```python
-import requests
-import torch
-from PIL import Image
-from io import BytesIO
-
 from diffusers import StableUnCLIPImg2ImgPipeline
+from diffusers.utils import load_image
+import torch
 
 pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
     "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
@@ -55,12 +52,10 @@ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
 pipe = pipe.to("cuda")
 
 url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
-
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
 
 images = pipe(init_image).images
-images[0].save("fantasy_landscape.png")
+images[0].save("variation_image.png")
 ```
 
 Optionally, you can also pass a prompt to `pipe` such as:
@@ -69,7 +64,50 @@
 prompt = "A fantasy landscape, trending on artstation"
 
 images = pipe(init_image, prompt=prompt).images
-images[0].save("fantasy_landscape.png")
+images[0].save("variation_image_two.png")
+```
+
+### Memory optimization
+
+If you are short on GPU memory, you can enable smart CPU offloading so that models that are not
+needed for the current computation step are offloaded to the CPU:
+
+```python
+from diffusers import StableUnCLIPImg2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
+)
+# Offload to CPU.
+pipe.enable_model_cpu_offload()
+
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
+init_image = load_image(url)
+
+images = pipe(init_image).images
+images[0]
+```
+
+Further memory optimizations are possible by enabling VAE slicing on the pipeline:
+
+```python
+from diffusers import StableUnCLIPImg2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
+)
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_slicing()
+
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
+init_image = load_image(url)
+
+images = pipe(init_image).images
+images[0]
 ```
 
 ### StableUnCLIPPipeline

From c2758e526ee63b76e39ca7d7a021296d0307313e Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 Mar 2023 06:44:36 +0530
Subject: [PATCH 2/6] initial commits.
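
This patch adds `TEXT_ENCODER_TARGET_MODULES` and two helpers to `training_utils.py`. A minimal
usage sketch of the intended wiring (not part of the diff below; the CLIP checkpoint id and the
optimizer setup are illustrative assumptions, only the helper itself comes from this patch):

```python
import torch
from transformers import CLIPTextModel

from diffusers.training_utils import get_lora_layers_for_text_encoder

# Any CLIP-style text encoder with q/k/v/out projections works here.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Collect one LoRAAttnProcessor per attention projection, wrapped so the
# parameters can be handed to an optimizer as a single nn.Module.
text_encoder_lora_layers = get_lora_layers_for_text_encoder(text_encoder)
optimizer = torch.optim.AdamW(text_encoder_lora_layers.parameters(), lr=1e-4)
```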
---
 src/diffusers/training_utils.py | 35 +++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 340b96e29ac5..53a2ccd920fd 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -5,10 +5,16 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 
+from .loaders import AttnProcsLayers
+from .models.attention_processor import LoRAAttnProcessor
 from .utils import deprecate
 
 
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
+
+
 def enable_full_determinism(seed: int):
     """
     Helper function for reproducible behavior during distributed training. See
@@ -42,6 +48,35 @@ def set_seed(seed: int):
     # ^^ safe to call this function even if cuda is not available
 
 
+def get_lora_layer_attribute(module_name: str) -> str:
+    """
+    Helper function to return the layer name in the `LoRAAttnProcessor` class corresponding to the original attention
+    layer module.
+    """
+    if "q_proj" in module_name:
+        return "to_q_lora"
+    elif "v_proj" in module_name:
+        return "to_v_lora"
+    elif "k_proj" in module_name:
+        return "to_k_lora"
+    else:
+        return "to_out_lora"
+
+
+def get_lora_layers_for_text_encoder(text_encoder: nn.Module):
+    """
+    Helper function to prepare the LoRA attention processors for the text encoder, which almost always comes from
+    `transformers`.
+    """
+    lora_attn_procs = {}
+    for name, module in text_encoder.named_modules():
+        if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
+
+    text_encoder_lora_layers = AttnProcsLayers(lora_attn_procs)
+    return text_encoder_lora_layers
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """

From 75c4601bd1a531c1e74b802d06c128c5b17fed54 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 Mar 2023 11:52:47 +0530
Subject: [PATCH 3/6] add: utilities to support text encoder + LoRA.
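
The intended round trip for the new mixin, sketched with this patch's API (the checkpoint id and
output directory are illustrative assumptions, and the training step itself is elided):

```python
from transformers import CLIPTextModel

from diffusers.loaders import TextEncoderLoRAMixin

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
wrapper = TextEncoderLoRAMixin(text_encoder)

# ... train wrapper.text_encoder_lora_layers ...

# Serialize only the LoRA parameters, then restore them by monkey-patching
# the attention projections of the wrapped text encoder.
wrapper.save_attn_procs("text-encoder-lora")
patched_text_encoder = wrapper.load_attn_procs("text-encoder-lora")
```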
---
 src/diffusers/loaders.py         | 314 ++++++++++++++++++++++++++++++-
 src/diffusers/training_utils.py  |  35 ----
 src/diffusers/utils/__init__.py  |   1 +
 src/diffusers/utils/constants.py |   1 +
 4 files changed, 315 insertions(+), 36 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index d6bb6fde6ac1..b00f4bfe02d9 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -16,10 +16,18 @@
 from typing import Callable, Dict, Union
 
 import torch
+import torch.nn as nn
 
 from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+    DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
+    TEXT_ENCODER_TARGET_MODULES,
+    deprecate,
+    is_safetensors_available,
+    logging,
+)
 
 
 if is_safetensors_available():
@@ -32,6 +40,9 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
+TEXT_ENCODER_LORA_WEIGHT_NAME = "pytorch_text_encoder_lora_weights.bin"
+TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE = "pytorch_text_encoder_lora_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -294,3 +305,304 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
 
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextEncoderLoRAMixin:
+    r"""
+    This class handles LoRA for the text encoder used in our pipelines. Its methods are mostly copied over from
+    [`~UNet2DConditionLoadersMixin`]; we couldn't fully reuse that class because the text encoder doesn't support
+    things like `self.set_attn_processor()`.
+
+    Args:
+        text_encoder (`nn.Module`):
+            The text encoder module underlying a [`~DiffusionPipeline`].
+    """
+
+    def __init__(self, text_encoder: nn.Module):
+        self.text_encoder = text_encoder
+        self.device = text_encoder.device
+        self.dtype = text_encoder.dtype
+        self._initialize_lora_layers()
+
+    def _initialize_lora_layers(self):
+        self.lora_attn_procs = {}
+        for name, module in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                self.lora_attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=module.out_features, cross_attention_dim=None
+                )
+
+        self.text_encoder_lora_layers = AttnProcsLayers(self.lora_attn_procs)
+
+    def load_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ) -> nn.Module:
+        r"""
+        Load pretrained attention processor layers into
+        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). Instead
+        of setting the attention processor layers (as done in [`~UNet2DConditionLoadersMixin.load_attn_procs`]), we
+        use the LoRA attention layers to monkey-patch the forward passes of the attention modules of the
+        `text_encoder`.
+
+        <Tip warning={true}>
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety
+                of the source. Please refer to the mirror site for more information.
+
+        Returns:
+            `nn.Module`: The text encoder module with the forward passes of its attention modules monkey-patched.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        </Tip>
+        """
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        if use_safetensors and not is_safetensors_available():
+            raise ValueError(
+                "`use_safetensors`=True but safetensors is not installed. "
+                "Please install safetensors with `pip install safetensors`."
+            )
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = is_safetensors_available()
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except IOError as e:
+                    if not allow_pickle:
+                        raise e
+                    # try loading non-safetensors weights
+                    pass
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        # fill attn processors
+        attn_processors = {}
+
+        is_lora = all("lora" in k for k in state_dict.keys())
+
+        if is_lora:
+            lora_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                lora_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in lora_grouped_dict.items():
+                rank = value_dict["to_k_lora.down.weight"].shape[0]
+                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+                attn_processors[key] = LoRAAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                )
+                attn_processors[key].load_state_dict(value_dict)
+
+        else:
+            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+
+        # set correct dtype & device
+        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+
+        return self._modify_text_encoder(attn_processors)
+
+    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]) -> nn.Module:
+        r"""
+        Monkey-patches the forward passes of the attention modules of the text encoder.
+
+        Args:
+            attn_processors: Dict[str, `LoRAAttnProcessor`]:
+                A dictionary mapping the module names to their corresponding [`~LoRAAttnProcessor`].
+
+        Returns:
+            `nn.Module`: The modified text encoder.
+        """
+        # Loop over the original attention modules.
+        for name, _ in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                # Retrieve the module and its corresponding LoRA processor.
+                module = self.text_encoder.get_submodule(name)
+                # Construct a new function that performs the LoRA merging. We will monkey-patch
+                # this forward pass.
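+                # The loop variables are bound as default arguments below because Python
+                # closures capture variables, not values; without the explicit binding,
+                # every patched forward would end up using the `old_forward` and
+                # `lora_layer` from the last iteration of this loop.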
+                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                old_forward = module.forward
+
+                def new_forward(x, old_forward=old_forward, lora_layer=lora_layer):
+                    return old_forward(x) + lora_layer(x)
+
+                # Monkey-patch.
+                module.forward = new_forward
+
+        return self.text_encoder
+
+    def _get_lora_layer_attribute(self, name: str) -> str:
+        if "q_proj" in name:
+            return "to_q_lora"
+        elif "v_proj" in name:
+            return "to_v_lora"
+        elif "k_proj" in name:
+            return "to_k_lora"
+        else:
+            return "to_out_lora"
+
+    def save_attn_procs(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Save an attention processor to a directory, so that it can be re-loaded using the
+        `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                needs to replace `torch.save` by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+        """
+        weight_name = weight_name or deprecate(
+            "weights_name",
+            "0.18.0",
+            "`weights_name` is deprecated, please use `weight_name` instead.",
+            take_from=kwargs,
+        )
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        model_to_save = self.text_encoder_lora_layers
+
+        # Save the model
+        state_dict = model_to_save.state_dict()
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME
+
+        # Save the model
+        save_function(state_dict, os.path.join(save_directory, weight_name))
+
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 53a2ccd920fd..340b96e29ac5 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -5,16 +5,10 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 
-from .loaders import AttnProcsLayers
-from .models.attention_processor import LoRAAttnProcessor
 from .utils import deprecate
 
 
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
-
-
 def enable_full_determinism(seed: int):
     """
     Helper function for reproducible behavior during distributed training. See
@@ -48,35 +42,6 @@ def set_seed(seed: int):
     # ^^ safe to call this function even if cuda is not available
 
 
-def get_lora_layer_attribute(module_name: str) -> str:
-    """
-    Helper function to return the layer name in the `LoRAAttnProcessor` class corresponding to the original attention
-    layer module.
- """ - if "q_proj" in module_name: - return "to_q_lora" - elif "v_proj" in module_name: - return "to_v_lora" - elif "k_proj" in module_name: - return "to_k_lora" - else: - return "to_out_lora" - - -def get_lora_layers_for_text_encoder(text_encoder: nn.Module): - """ - Helper function to prepare the LoRA attention processors for the text encoder which almost always comes from - `transformers`. - """ - lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) - - text_encoder_lora_layers = AttnProcsLayers(lora_attn_procs) - return text_encoder_lora_layers - - # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 615804c91a19..93ccf119c60c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) from .deprecation_utils import deprecate diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index b9e60a2a873b..1134ba6fb656 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,3 +30,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] From b500174cf6f523b26415b94388c8523267df0295 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 29 Mar 2023 13:03:38 +0530 Subject: [PATCH 4/6] add: tests. --- tests/test_text_encoder_lora.py | 99 +++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/test_text_encoder_lora.py diff --git a/tests/test_text_encoder_lora.py b/tests/test_text_encoder_lora.py new file mode 100644 index 000000000000..27060da7a442 --- /dev/null +++ b/tests/test_text_encoder_lora.py @@ -0,0 +1,99 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import copy
+import os
+import tempfile
+import unittest
+
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel
+
+from diffusers.loaders import TextEncoderLoRAMixin
+from diffusers.utils import torch_device
+
+
+class TextEncoderLoRATests(unittest.TestCase):
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config).to(torch_device)
+        text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))
+        lora_attn_procs = text_encoder_lora_wrapper.lora_attn_procs
+        text_encoder_lora = text_encoder_lora_wrapper._modify_text_encoder(lora_attn_procs)
+        return text_encoder, text_encoder_lora, text_encoder_lora_wrapper
+
+    def get_dummy_inputs(self):
+        batch_size = 1
+        sequence_length = 10
+        generator = torch.manual_seed(0)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+        return input_ids
+
+    def test_lora_default_case(self):
+        text_encoder, text_encoder_lora, _ = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+    def test_lora_save_load(self):
+        text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            text_encoder_lora_wrapper.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.bin")))
+            text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+    def test_lora_save_load_safetensors(self):
+        text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            text_encoder_lora_wrapper.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.safetensors")))
+            text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))

From 6fcbb5a77af8a615fd26e6adbef419fa027f1165 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 Mar 2023 13:30:18 +0530
Subject: [PATCH 5/6] add: entry to the docs.
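
This also threads the trainable layers through `save_attn_procs()` explicitly. A sketch of what
the call-site change implies (the `wrapper` variable is illustrative; the API names come from
this series):

```python
# Patch 3 signature: the wrapper serialized its internally tracked layers.
wrapper.save_attn_procs("text-encoder-lora")

# After this patch: the trainable LoRA layers are passed in explicitly, e.g. so a
# training script can hand over accelerator-unwrapped or EMA copies of the weights.
wrapper.save_attn_procs("text-encoder-lora", wrapper.text_encoder_lora_layers)
```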
---
 docs/source/en/api/loaders.mdx | 4 ++++
 src/diffusers/loaders.py       | 7 +++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx
index 1d55bd03c064..1243930bec9f 100644
--- a/docs/source/en/api/loaders.mdx
+++ b/docs/source/en/api/loaders.mdx
@@ -28,3 +28,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
 ### UNet2DConditionLoadersMixin
 
 [[autodoc]] loaders.UNet2DConditionLoadersMixin
+
+### TextEncoderLoRAMixin
+
+[[autodoc]] loaders.TextEncoderLoRAMixin
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index b00f4bfe02d9..87035c38f72f 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -548,6 +548,7 @@ def _get_lora_layer_attribute(self, name: str) -> str:
     def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
+        text_encoder_lora_layers: nn.Module,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -556,11 +557,13 @@ def save_attn_procs(
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
-        `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
+        [`~loaders.TextEncoderLoRAMixin.load_attn_procs`] method.
 
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
+            text_encoder_lora_layers (`nn.Module`):
+                The LoRA trainable parameters, provided as an `nn.Module`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -591,7 +594,7 @@ def save_function(weights, filename):
 
         os.makedirs(save_directory, exist_ok=True)
 
-        model_to_save = self.text_encoder_lora_layers
+        model_to_save = text_encoder_lora_layers
 
         # Save the model
         state_dict = model_to_save.state_dict()

From 3b65ac142fc9502f0cc39fdd907e760ada2ec382 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 29 Mar 2023 14:02:13 +0530
Subject: [PATCH 6/6] fix: tests.

---
 tests/test_text_encoder_lora.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_text_encoder_lora.py b/tests/test_text_encoder_lora.py
index 27060da7a442..0ec150e22755 100644
--- a/tests/test_text_encoder_lora.py
+++ b/tests/test_text_encoder_lora.py
@@ -70,7 +70,7 @@ def test_lora_save_load(self):
             original_outputs = text_encoder(inputs)[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_lora_wrapper.save_attn_procs(tmpdirname)
+            text_encoder_lora_wrapper.save_attn_procs(tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.bin")))
             text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
 
@@ -88,7 +88,9 @@ def test_lora_save_load_safetensors(self):
             original_outputs = text_encoder(inputs)[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_lora_wrapper.save_attn_procs(tmpdirname, safe_serialization=True)
+            text_encoder_lora_wrapper.save_attn_procs(
+                tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers, safe_serialization=True
+            )
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.safetensors")))
             text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
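
A possible end-to-end inference sketch combining this series with the existing UNet LoRA loader.
The repository id and the two LoRA directories are illustrative assumptions; `TextEncoderLoRAMixin`
and its methods come from the patches above:

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.loaders import TextEncoderLoRAMixin

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# UNet LoRA goes through the existing UNet2DConditionLoadersMixin path.
pipe.unet.load_attn_procs("path/to/unet-lora")

# Construct the wrapper after moving the pipeline to its final device so the
# patched LoRA layers inherit the right device and dtype, then swap in the
# monkey-patched text encoder.
wrapper = TextEncoderLoRAMixin(pipe.text_encoder)
pipe.text_encoder = wrapper.load_attn_procs("path/to/text-encoder-lora")

image = pipe("a fantasy landscape, trending on artstation").images[0]
image.save("lora_sample.png")
```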