diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx
index 1d55bd03c064..1243930bec9f 100644
--- a/docs/source/en/api/loaders.mdx
+++ b/docs/source/en/api/loaders.mdx
@@ -28,3 +28,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
### UNet2DConditionLoadersMixin
[[autodoc]] loaders.UNet2DConditionLoadersMixin
+
+### TextEncoderLoRAMixin
+
+[[autodoc]] loaders.TextEncoderLoRAMixin
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index d6bb6fde6ac1..87035c38f72f 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -16,10 +16,18 @@
from typing import Callable, Dict, Union
import torch
+import torch.nn as nn
from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+ DIFFUSERS_CACHE,
+ HF_HUB_OFFLINE,
+ TEXT_ENCODER_TARGET_MODULES,
+ deprecate,
+ is_safetensors_available,
+ logging,
+)
if is_safetensors_available():
@@ -32,6 +40,9 @@
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+TEXT_ENCODER_LORA_WEIGHT_NAME = "pytorch_text_encoder_lora_weights.bin"
+TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE = "pytorch_text_encoder_lora_weights.safetensors"
+
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -294,3 +305,307 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextEncoderLoRAMixin:
+ r"""
+    Handles LoRA layers for the text encoder used in our pipelines. The methods of this class are largely adapted
+    from [`~UNet2DConditionLoadersMixin`]. We cannot fully reuse that class because the text encoder does not expose
+    helpers such as `self.set_attn_processor()`.
+
+ Args:
+ text_encoder (`nn.Module`):
+ The text encoder module underlying a [`~DiffusionPipeline`].
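+
+    Example (a minimal sketch; the Stable Diffusion checkpoint below is only illustrative, any CLIP text encoder
+    works):
+
+    ```py
+    from transformers import CLIPTextModel
+
+    from diffusers.loaders import TextEncoderLoRAMixin
+
+    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+    lora_wrapper = TextEncoderLoRAMixin(text_encoder)
+
+    # Freshly initialized LoRA layers for every attention projection of the text encoder, ready to be
+    # trained and then saved with `save_attn_procs`.
+    print(len(lora_wrapper.lora_attn_procs))
+    ```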
+ """
+
+ def __init__(self, text_encoder: nn.Module):
+ self.text_encoder = text_encoder
+ self.device = text_encoder.device
+ self.dtype = text_encoder.dtype
+ self._initialize_lora_layers()
+
+ def _initialize_lora_layers(self):
+ self.lora_attn_procs = {}
+ for name, module in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+ self.lora_attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=module.out_features, cross_attention_dim=None
+ )
+
+ self.text_encoder_lora_layers = AttnProcsLayers(self.lora_attn_procs)
+
+ def load_attn_procs(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+ ) -> nn.Module:
+ r"""
+ Load pretrained attention processor layers into
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). Instead
+ of setting the attention processing layers (as done in [`~UNet2DConditionLoadersMixin.load_attn_procs`]), we
+ use the LoRA attention layers to monkey-patch the forward passes of the attention modules of the
+ `text_encoder`.
+
+        <Tip warning={true}>
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing the LoRA attention processor weights saved with
+                      [`~TextEncoderLoRAMixin.save_attn_procs`], e.g., `./my_model_directory/`.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ Returns:
+ `nn.Module`: The text encoder module with the forward passes of its attention modules monkey-patched.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        </Tip>
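+
+        Example (a minimal sketch; `"path/to/lora"` is a placeholder for a directory or Hub repo containing weights
+        saved with [`~TextEncoderLoRAMixin.save_attn_procs`]):
+
+        ```py
+        from transformers import CLIPTextModel
+
+        from diffusers.loaders import TextEncoderLoRAMixin
+
+        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+        lora_wrapper = TextEncoderLoRAMixin(text_encoder)
+
+        # Monkey-patches the attention projections of `text_encoder` in place and returns the patched module.
+        text_encoder = lora_wrapper.load_attn_procs("path/to/lora")
+        ```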
+ """
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+                "`use_safetensors`=True but safetensors is not installed. Please install it with `pip install safetensors`."
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ # fill attn processors
+ attn_processors = {}
+
+ is_lora = all("lora" in k for k in state_dict.keys())
+
+ if is_lora:
+ lora_grouped_dict = defaultdict(dict)
+ for key, value in state_dict.items():
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+ lora_grouped_dict[attn_processor_key][sub_key] = value
+
+ for key, value_dict in lora_grouped_dict.items():
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+ attn_processors[key] = LoRAAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+ )
+ attn_processors[key].load_state_dict(value_dict)
+
+ else:
+ raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+
+ # set correct dtype & device
+ attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+
+ return self._modify_text_encoder(attn_processors)
+
+ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]) -> nn.Module:
+ r"""
+ Monkey-patches the forward passes of attention modules of the text encoder.
+
+ Args:
+            attn_processors (`Dict[str, LoRAAttnProcessor]`):
+                A dictionary mapping module names to their corresponding [`~LoRAAttnProcessor`].
+ Returns:
+ `nn.Module`: The modified text encoder.
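+
+        Example (illustrative; in normal use this method is invoked through
+        [`~TextEncoderLoRAMixin.load_attn_procs`] rather than called directly):
+
+        ```py
+        from transformers import CLIPTextModel
+
+        from diffusers.loaders import TextEncoderLoRAMixin
+
+        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+        lora_wrapper = TextEncoderLoRAMixin(text_encoder)
+
+        # Patch the text encoder with the freshly initialized (untrained) LoRA layers.
+        patched_text_encoder = lora_wrapper._modify_text_encoder(lora_wrapper.lora_attn_procs)
+        ```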
+ """
+ # Loop over the original attention modules.
+ for name, _ in self.text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                # Retrieve the module and its corresponding LoRA processor.
+                module = self.text_encoder.get_submodule(name)
+                # Construct a new function that adds the LoRA contribution to the original output.
+                # We will monkey-patch this forward pass.
+                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                old_forward = module.forward
+
+                # Bind `old_forward` and `lora_layer` as default arguments so that each patched module keeps
+                # its own references; Python closures are late-binding, so capturing the loop variables
+                # directly would make every patched module use the processor from the last iteration.
+                def new_forward(x, old_forward=old_forward, lora_layer=lora_layer):
+                    return old_forward(x) + lora_layer(x)
+
+                # Monkey-patch.
+                module.forward = new_forward
+ return self.text_encoder
+
+ def _get_lora_layer_attribute(self, name: str) -> str:
+ if "q_proj" in name:
+ return "to_q_lora"
+ elif "v_proj" in name:
+ return "to_v_lora"
+ elif "k_proj" in name:
+ return "to_k_lora"
+ else:
+ return "to_out_lora"
+
+ def save_attn_procs(
+ self,
+ save_directory: Union[str, os.PathLike],
+ text_encoder_lora_layers: nn.Module,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ **kwargs,
+ ):
+ r"""
+        Save the text encoder's LoRA attention processor layers to a directory, so that they can be re-loaded using
+        the [`~loaders.TextEncoderLoRAMixin.load_attn_procs`] method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ text_encoder_lora_layers (`nn.Module`):
+ LoRA trainable parameters provided as `nn.Module`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training (e.g.,
+                on TPUs) when you need to call this function on all processes. In this case, set
+                `is_main_process=True` only on the main process to avoid race conditions.
+            weight_name (`str`, *optional*):
+                The name of the file to save the weights to. Defaults to `TEXT_ENCODER_LORA_WEIGHT_NAME`, or to
+                `TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE` when `safe_serialization=True`.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs)
+                when you need to replace `torch.save` with another method. Can be configured with the environment
+                variable `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the weights in the safetensors format instead of the default PyTorch (pickle) format.
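+
+        Example (a minimal sketch; `"path/to/lora"` is a placeholder output directory):
+
+        ```py
+        from transformers import CLIPTextModel
+
+        from diffusers.loaders import TextEncoderLoRAMixin
+
+        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+        lora_wrapper = TextEncoderLoRAMixin(text_encoder)
+
+        # In practice you would train `lora_wrapper.text_encoder_lora_layers` first; here the freshly
+        # initialized layers are saved so they can be re-loaded with `load_attn_procs`.
+        lora_wrapper.save_attn_procs("path/to/lora", lora_wrapper.text_encoder_lora_layers)
+        ```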
+ """
+ weight_name = weight_name or deprecate(
+ "weights_name",
+ "0.18.0",
+ "`weights_name` is deprecated, please use `weight_name` instead.",
+ take_from=kwargs,
+ )
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = text_encoder_lora_layers
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weight_name))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 615804c91a19..93ccf119c60c 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -30,6 +30,7 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
+ TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index b9e60a2a873b..1134ba6fb656 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -30,3 +30,4 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
diff --git a/tests/test_text_encoder_lora.py b/tests/test_text_encoder_lora.py
new file mode 100644
index 000000000000..0ec150e22755
--- /dev/null
+++ b/tests/test_text_encoder_lora.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import os
+import tempfile
+import unittest
+
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel
+
+from diffusers.loaders import TextEncoderLoRAMixin
+from diffusers.utils import torch_device
+
+
+class TextEncoderLoRATests(unittest.TestCase):
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config).to(torch_device)
+ text_encoder_lora_wrapper = TextEncoderLoRAMixin(copy.deepcopy(text_encoder))
+ lora_attn_procs = text_encoder_lora_wrapper.lora_attn_procs
+ text_encoder_lora = text_encoder_lora_wrapper._modify_text_encoder(lora_attn_procs)
+ return text_encoder, text_encoder_lora, text_encoder_lora_wrapper
+
+ def get_dummy_inputs(self):
+ batch_size = 1
+ sequence_length = 10
+ generator = torch.manual_seed(0)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ return input_ids
+
+ def test_lora_default_case(self):
+ text_encoder, text_encoder_lora, _ = self.get_dummy_components()
+ inputs = self.get_dummy_inputs()
+
+ with torch.no_grad():
+ original_outputs = text_encoder(inputs)[0]
+ lora_outputs = text_encoder_lora(inputs)[0]
+
+ # Outputs shouldn't match.
+ self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+ def test_lora_save_load(self):
+ text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+ inputs = self.get_dummy_inputs()
+
+ with torch.no_grad():
+ original_outputs = text_encoder(inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ text_encoder_lora_wrapper.save_attn_procs(tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.bin")))
+ text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+ with torch.no_grad():
+ lora_outputs = text_encoder_lora(inputs)[0]
+
+ # Outputs shouldn't match.
+ self.assertFalse(torch.allclose(original_outputs, lora_outputs))
+
+ def test_lora_save_load_safetensors(self):
+ text_encoder, _, text_encoder_lora_wrapper = self.get_dummy_components()
+ inputs = self.get_dummy_inputs()
+
+ with torch.no_grad():
+ original_outputs = text_encoder(inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ text_encoder_lora_wrapper.save_attn_procs(
+ tmpdirname, text_encoder_lora_wrapper.text_encoder_lora_layers, safe_serialization=True
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_text_encoder_lora_weights.safetensors")))
+ text_encoder_lora = text_encoder_lora_wrapper.load_attn_procs(tmpdirname)
+
+ with torch.no_grad():
+ lora_outputs = text_encoder_lora(inputs)[0]
+
+ # Outputs shouldn't match.
+ self.assertFalse(torch.allclose(original_outputs, lora_outputs))