diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6d273de5ca9d..b1d3bbb122e6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,6 +33,7 @@ SlicedAttnAddedKVProcessor, XFormersAttnProcessor, ) +from .models.lora import Conv2dWithLoRA, LinearWithLoRA, LoRAConv2dLayer, LoRALinearLayer from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, @@ -415,6 +416,37 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + def _load_lora_aux(self, state_dict, network_alpha=None): + # print("\n".join(sorted(state_dict.keys()))) + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["lora.down.weight"].shape[0] + hidden_size = value_dict["lora.up.weight"].shape[0] + target_modules = [module for name, module in self.named_modules() if name == key] + if len(target_modules) == 0: + logger.warning(f"Could not find module {key} in the model. Skipping.") + continue + + target_module = target_modules[0] + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + + lora = None + if isinstance(target_module, Conv2dWithLoRA): + lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha) + elif isinstance(target_module, LinearWithLoRA): + lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha) + else: + raise ValueError(f"Module {key} is not a Conv2dWithLoRA or LinearWithLoRA module.") + lora.load_state_dict(value_dict) + lora.to(device=self.device, dtype=self.dtype) + + # install lora + target_module.lora_layer = lora + class TextualInversionLoaderMixin: r""" @@ -917,7 +949,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) + state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha = self._convert_kohya_lora_to_diffusers( + state_dict + ) + self.unet._load_lora_aux(unet_state_dict_aux, network_alpha=network_alpha) + self._load_lora_aux_for_text_encoder(te_state_dict_aux, network_alpha=network_alpha) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as @@ -1282,6 +1318,8 @@ def save_function(weights, filename): def _convert_kohya_lora_to_diffusers(self, state_dict): unet_state_dict = {} te_state_dict = {} + unet_state_dict_aux = {} + te_state_dict_aux = {} network_alpha = None for key, value in state_dict.items(): @@ -1306,12 +1344,20 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") unet_state_dict[diffusers_name] = value unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "ff" in diffusers_name: + unet_state_dict_aux[diffusers_name] = value + unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict_aux[diffusers_name] = value + unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] elif lora_name.startswith("lora_te_"): diffusers_name = key.replace("lora_te_", "").replace("_", ".") diffusers_name = diffusers_name.replace("text.model", "text_model") @@ -1323,11 +1369,45 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): if "self_attn" in diffusers_name: te_state_dict[diffusers_name] = value te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "mlp" in diffusers_name: + te_state_dict_aux[diffusers_name] = value + te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alpha + return new_state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha + + def _load_lora_aux_for_text_encoder(self, state_dict, network_alpha=None): + # print("\n".join(sorted(state_dict.keys()))) + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["lora.down.weight"].shape[0] + target_modules = [module for name, module in self.text_encoder.named_modules() if name == key] + if len(target_modules) == 0: + logger.warning(f"Could not find module {key} in the model. Skipping.") + continue + + target_module = target_modules[0] + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha) + lora_layer.load_state_dict(value_dict) + lora_layer.to(device=self.text_encoder.device, dtype=self.text_encoder.dtype) + + old_forward = target_module.forward + + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) + + return new_forward + + # Monkey-patch. + target_module.forward = make_new_forward(old_forward, lora_layer) class FromCkptMixin: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8805257ebe9a..8476fe31f675 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,6 +21,7 @@ from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings +from .lora import LinearWithLoRA @maybe_allow_in_graph @@ -222,7 +223,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(nn.Linear(inner_dim, dim_out)) + self.net.append(LinearWithLoRA(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) @@ -266,7 +267,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = LinearWithLoRA(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e0404a83cc9a..83c8b8043db5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -19,6 +19,7 @@ from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -504,36 +505,6 @@ def __call__( return hidden_states -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. - # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - self.network_alpha = network_alpha - self.rank = rank - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank - - return up_hidden_states.to(orig_dtype) - - class LoRAAttnProcessor(nn.Module): r""" Processor for implementing the LoRA attention mechanism. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py new file mode 100644 index 000000000000..2bd43ef8dc17 --- /dev/null +++ b/src/diffusers/models/lora.py @@ -0,0 +1,111 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from torch import nn + + +# moved from attention_processor.py +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +# copied from LoRAConv2dLayer +class LoRAConv2dLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False) + self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class Conv2dWithLoRA(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def forward(self, x): + if self.lora_layer is None: + return super().forward(x) + else: + return super().forward(x) + self.lora_layer(x) + + +class LinearWithLoRA(nn.Linear): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def forward(self, x): + if self.lora_layer is None: + return super().forward(x) + else: + return super().forward(x) + self.lora_layer(x) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index ec4cb371845f..ecc66c1027c9 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,6 +23,7 @@ from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed +from .lora import Conv2dWithLoRA from .modeling_utils import ModelMixin @@ -146,7 +147,7 @@ def __init__( if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.proj_in = Conv2dWithLoRA(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -202,7 +203,7 @@ def __init__( if use_linear_projection: self.proj_out = nn.Linear(inner_dim, in_channels) else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = Conv2dWithLoRA(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)