diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx
index 1d55bd03c064..1243930bec9f 100644
--- a/docs/source/en/api/loaders.mdx
+++ b/docs/source/en/api/loaders.mdx
@@ -28,3 +28,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
 ### UNet2DConditionLoadersMixin
 
 [[autodoc]] loaders.UNet2DConditionLoadersMixin
+
+### TextEncoderLoRAMixin
+
+[[autodoc]] loaders.TextEncoderLoRAMixin
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index d6bb6fde6ac1..87035c38f72f 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -16,10 +16,18 @@
 from typing import Callable, Dict, Union
 
 import torch
+import torch.nn as nn
 
 from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+    DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
+    TEXT_ENCODER_TARGET_MODULES,
+    deprecate,
+    is_safetensors_available,
+    logging,
+)
 
 
 if is_safetensors_available():
@@ -32,6 +40,9 @@
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
+TEXT_ENCODER_LORA_WEIGHT_NAME = "pytorch_text_encoder_lora_weights.bin"
+TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE = "pytorch_text_encoder_lora_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -294,3 +305,307 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
 
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextEncoderLoRAMixin:
+    r"""
+    This class handles the text encoder used in our pipelines with LoRA. Its methods are largely adapted from
+    [`~UNet2DConditionLoadersMixin`]; we cannot reuse that class directly because the text encoder does not expose
+    helpers such as `self.set_attn_processor()`.
+
+    Args:
+        text_encoder (`nn.Module`):
+            The text encoder module underlying a [`~DiffusionPipeline`].
+    """
+
+    def __init__(self, text_encoder: nn.Module):
+        self.text_encoder = text_encoder
+        self.device = text_encoder.device
+        self.dtype = text_encoder.dtype
+        self._initialize_lora_layers()
+
+    def _initialize_lora_layers(self):
+        self.lora_attn_procs = {}
+        for name, module in self.text_encoder.named_modules():
+            if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
+                self.lora_attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=module.out_features, cross_attention_dim=None
+                )
+
+        self.text_encoder_lora_layers = AttnProcsLayers(self.lora_attn_procs)
+
+    def load_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ) -> nn.Module:
+        r"""
+        Load pretrained attention processor layers into
+        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). Instead
+        of setting the attention processors directly (as done in [`~UNet2DConditionLoadersMixin.load_attn_procs`]), we
+        use the LoRA attention layers to monkey-patch the forward passes of the attention modules of the
+        `text_encoder`.
+
+        <Tip warning={true}>
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+
+        Returns:
+            `nn.Module`: The text encoder module with the forward passes of its attention modules monkey-patched.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        </Tip>
+        """
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        if use_safetensors and not is_safetensors_available():
+            raise ValueError(
+                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with "
+                "`pip install safetensors`."
+            )
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = is_safetensors_available()
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except IOError as e:
+                    if not allow_pickle:
+                        raise e
+                    # try loading non-safetensors weights
+                    pass
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        # fill attn processors
+        attn_processors = {}
+
+        is_lora = all("lora" in k for k in state_dict.keys())
+
+        if is_lora:
+            lora_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                lora_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in lora_grouped_dict.items():
+                rank = value_dict["to_k_lora.down.weight"].shape[0]
+                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+                attn_processors[key] = LoRAAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                )
+                attn_processors[key].load_state_dict(value_dict)
+
+        else:
+            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+
+        # set correct dtype & device
+        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+
+        return self._modify_text_encoder(attn_processors)
+
+    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]) -> nn.Module:
+        r"""
+        Monkey-patches the forward passes of the attention modules of the text encoder.
+
+        Args:
+            attn_processors (`Dict[str, LoRAAttnProcessor]`):
+                A dictionary mapping module names to their corresponding [`~LoRAAttnProcessor`].
+
+        Returns:
+            `nn.Module`: The modified text encoder.
+        """
+        # Loop over the original attention modules.
+        for name, _ in self.text_encoder.named_modules():
+            if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
+                # Retrieve the module and its corresponding LoRA processor.
+                module = self.text_encoder.get_submodule(name)
+                # Construct a new function that performs the LoRA merging. We will monkey patch
+                # this forward pass.
+                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                old_forward = module.forward
+
+                # Bind `old_forward` and `lora_layer` as default arguments so that each patched module
+                # keeps its own references; a plain closure would late-bind both names to the values
+                # from the final loop iteration.
+                def new_forward(x, old_forward=old_forward, lora_layer=lora_layer):
+                    return old_forward(x) + lora_layer(x)
+
+                # Monkey-patch.
+                module.forward = new_forward
+        return self.text_encoder
+
+    def _get_lora_layer_attribute(self, name: str) -> str:
+        if "q_proj" in name:
+            return "to_q_lora"
+        elif "v_proj" in name:
+            return "to_v_lora"
+        elif "k_proj" in name:
+            return "to_k_lora"
+        else:
+            return "to_out_lora"
+
+    def save_attn_procs(
+        self,
+        save_directory: Union[str, os.PathLike],
+        text_encoder_lora_layers: nn.Module,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Save an attention processor to a directory, so that it can be re-loaded using the
+        [`~loaders.TextEncoderLoRAMixin.load_attn_procs`] method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            text_encoder_lora_layers (`nn.Module`):
+                LoRA trainable parameters provided as an `nn.Module`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training
+                (e.g., on TPUs) when you need to call this function on all processes. In this case, set
+                `is_main_process=True` only on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs)
+                when one needs to replace `torch.save` with another method. Can be configured with the environment
+                variable `DIFFUSERS_SAVE_MODE`.
+        """
+        weight_name = weight_name or deprecate(
+            "weights_name",
+            "0.18.0",
+            "`weights_name` is deprecated, please use `weight_name` instead.",
+            take_from=kwargs,
+        )
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        model_to_save = text_encoder_lora_layers
+
+        # Save the model
+        state_dict = model_to_save.state_dict()
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME
+
+        # Save the model
+        save_function(state_dict, os.path.join(save_directory, weight_name))
+
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 615804c91a19..93ccf119c60c 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -30,6 +30,7 @@
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index b9e60a2a873b..1134ba6fb656 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -30,3 +30,4 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
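For reference, a minimal sketch of how the new mixin is meant to be used end to end, mirroring the tests added below. It assumes a `diffusers` installation with this change applied; the save directory name is only illustrative.

```python
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.loaders import TextEncoderLoRAMixin

# Any CLIP text encoder works; a default config keeps the example self-contained.
text_encoder = CLIPTextModel(CLIPTextConfig())

# Wrapping the encoder attaches one LoRAAttnProcessor per q/k/v/out projection and
# collects their parameters in `text_encoder_lora_layers` for training and saving.
wrapper = TextEncoderLoRAMixin(text_encoder)
print(sum(p.numel() for p in wrapper.text_encoder_lora_layers.parameters()))

# Persist only the LoRA parameters ...
wrapper.save_attn_procs("./text-encoder-lora", wrapper.text_encoder_lora_layers)

# ... and monkey-patch the text encoder's attention projections with them later.
patched_text_encoder = wrapper.load_attn_procs("./text-encoder-lora")
```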
diff --git a/tests/test_text_encoder_lora.py b/tests/test_text_encoder_lora.py
new file mode 100644
index 000000000000..0ec150e22755
--- /dev/null
+++ b/tests/test_text_encoder_lora.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import os
+import tempfile
+import unittest
+
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel
+
+from diffusers.loaders import TextEncoderLoRAMixin
+from diffusers.utils import torch_device
+
+
+class TextEncoderLoRATests(unittest.TestCase):
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config).to(torch_device)
+        text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))
+        lora_attn_procs = text_encoder_lora_wrapper.lora_attn_procs
+        text_encoder_lora = text_encoder_lora_wrapper._modify_text_encoder(lora_attn_procs)
+        return text_encoder, text_encoder_lora, text_encoder_lora_wrapper
+
+    def get_dummy_inputs(self):
+        batch_size = 1
+        sequence_length = 10
+        generator = torch.manual_seed(0)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+        return input_ids
+
+    def test_lora_default_case(self):
+        text_encoder, text_encoder_lora, _ = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+    def test_lora_save_load(self):
+        text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            text_encoder_lora_wrapper.save_attn_procs(tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.bin")))
+            text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+    def test_lora_save_load_safetensors(self):
+        text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+        inputs = self.get_dummy_inputs()
+
+        with torch.no_grad():
+            original_outputs = text_encoder(inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            text_encoder_lora_wrapper.save_attn_procs(
+                tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers, safe_serialization=True
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.safetensors")))
+            text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            lora_outputs = text_encoder_lora(inputs)[0]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(original_outputs, lora_outputs))
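To see in isolation the behaviour these tests assert, here is a small self-contained sketch of the forward-pass monkey-patching that `_modify_text_encoder` performs, applied to a plain `nn.Linear`; the low-rank residual below is hand-rolled for illustration rather than taken from the library's `LoRAAttnProcessor`.

```python
import torch
import torch.nn as nn

hidden_size, rank = 32, 4
proj = nn.Linear(hidden_size, hidden_size)  # stands in for a CLIP q/k/v/out projection

# A generic low-rank (LoRA-style) residual: project down to `rank`, then back up.
lora_down = nn.Linear(hidden_size, rank, bias=False)
lora_up = nn.Linear(rank, hidden_size, bias=False)

old_forward = proj.forward

# Bind the captured callable as a default argument, as in `_modify_text_encoder` above,
# so the patch keeps pointing at this module's original forward.
def new_forward(x, old_forward=old_forward):
    return old_forward(x) + lora_up(lora_down(x))

proj.forward = new_forward

x = torch.randn(1, 10, hidden_size)
# The patched module now returns the original output plus the LoRA residual, which is
# what the tests above check by comparing against the unpatched encoder.
assert not torch.allclose(proj(x), old_forward(x))
```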