84 changes: 82 additions & 2 deletions src/diffusers/loaders.py
@@ -33,6 +33,7 @@
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
)
from .models.lora import Conv2dWithLoRA, LinearWithLoRA, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
@@ -415,6 +416,37 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def _load_lora_aux(self, state_dict, network_alpha=None):
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value

for key, value_dict in lora_grouped_dict.items():
rank = value_dict["lora.down.weight"].shape[0]
hidden_size = value_dict["lora.up.weight"].shape[0]
target_modules = [module for name, module in self.named_modules() if name == key]
if len(target_modules) == 0:
logger.warning(f"Could not find module {key} in the model. Skipping.")
continue

target_module = target_modules[0]
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}

lora = None
if isinstance(target_module, Conv2dWithLoRA):
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
elif isinstance(target_module, LinearWithLoRA):
lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
else:
raise ValueError(f"Module {key} is not a Conv2dWithLoRA or LinearWithLoRA module.")
lora.load_state_dict(value_dict)
lora.to(device=self.device, dtype=self.dtype)

# install lora
target_module.lora_layer = lora
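
A quick illustration of the grouping above (editor's sketch, not part of the diff): each aux key is assumed to end in the three components `lora.{down,up}.weight`, so stripping the last three dot-separated parts leaves the path of the module to patch. The key below is hypothetical and only shows the split.

    from collections import defaultdict

    # Hypothetical aux key as produced by _convert_kohya_lora_to_diffusers.
    key = "down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.lora.down.weight"

    module_path = ".".join(key.split(".")[:-3])  # "down_blocks.0. ... .ff.net.2", looked up via named_modules()
    sub_key = ".".join(key.split(".")[-3:])      # "lora.down.weight"

    lora_grouped = defaultdict(dict)
    lora_grouped[module_path][sub_key] = None    # placeholder for the actual weight tensor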


class TextualInversionLoaderMixin:
r"""
Expand Down Expand Up @@ -917,7 +949,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha = self._convert_kohya_lora_to_diffusers(
state_dict
)
self.unet._load_lora_aux(unet_state_dict_aux, network_alpha=network_alpha)
self._load_lora_aux_for_text_encoder(te_state_dict_aux, network_alpha=network_alpha)

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
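
A usage sketch (editor's note, not part of the diff): with this change, loading a kohya-ss style file through the regular `load_lora_weights` entry point also wires up the ff/proj_in/proj_out and text-encoder MLP LoRA weights via the new aux paths. The model id and file name below are illustrative.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Hypothetical kohya-ss LoRA checkpoint in the current directory.
    pipe.load_lora_weights(".", weight_name="kohya_style_lora.safetensors")

    image = pipe("a portrait photo, detailed lighting", num_inference_steps=25).images[0]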
@@ -1282,6 +1318,8 @@ def save_function(weights, filename):
def _convert_kohya_lora_to_diffusers(self, state_dict):
unet_state_dict = {}
te_state_dict = {}
unet_state_dict_aux = {}
te_state_dict_aux = {}
network_alpha = None

for key, value in state_dict.items():
@@ -1306,12 +1344,20 @@ def _convert_kohya_lora_to_diffusers(self, state_dict):
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "ff" in diffusers_name:
unet_state_dict_aux[diffusers_name] = value
unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict_aux[diffusers_name] = value
unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1323,11 +1369,45 @@ def _convert_kohya_lora_to_diffusers(self, state_dict):
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "mlp" in diffusers_name:
te_state_dict_aux[diffusers_name] = value
te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alpha
return new_state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha

def _load_lora_aux_for_text_encoder(self, state_dict, network_alpha=None):
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value

for key, value_dict in lora_grouped_dict.items():
rank = value_dict["lora.down.weight"].shape[0]
target_modules = [module for name, module in self.text_encoder.named_modules() if name == key]
if len(target_modules) == 0:
logger.warning(f"Could not find module {key} in the model. Skipping.")
continue

target_module = target_modules[0]
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
lora_layer.load_state_dict(value_dict)
lora_layer.to(device=self.text_encoder.device, dtype=self.text_encoder.dtype)

old_forward = target_module.forward

def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
target_module.forward = make_new_forward(old_forward, lora_layer)
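
A note on the factory above (editor's sketch): `make_new_forward` binds `old_forward` and `lora_layer` as arguments because Python closures capture loop variables by reference; a `new_forward` defined inline in the loop would end up calling the last iteration's layer for every module. Standalone illustration of the patch, independent of diffusers:

    import torch
    from torch import nn

    linear = nn.Linear(8, 8)
    delta = nn.Linear(8, 8, bias=False)  # stand-in for a LoRA layer
    nn.init.zeros_(delta.weight)         # zero-init, so the patch is initially a no-op

    old_forward = linear.forward

    def make_new_forward(old_forward, extra_layer):
        def new_forward(x):
            return old_forward(x) + extra_layer(x)
        return new_forward

    linear.forward = make_new_forward(old_forward, delta)  # monkey-patch the instance

    x = torch.randn(2, 8)
    assert torch.allclose(linear(x), old_forward(x))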


class FromCkptMixin:
5 changes: 3 additions & 2 deletions src/diffusers/models/attention.py
@@ -21,6 +21,7 @@
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LinearWithLoRA


@maybe_allow_in_graph
@@ -222,7 +223,7 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net.append(LinearWithLoRA(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
@@ -266,7 +267,7 @@ class GEGLU(nn.Module):

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
self.proj = LinearWithLoRA(dim_in, dim_out * 2)

def gelu(self, gate):
if gate.device.type != "mps":
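
For orientation (editor's sketch): after this change the project-out layer of `FeedForward` sits at `net[2]` (activation, dropout, projection), which is exactly the `ff.net.2` module path the aux loader looks up.

    from diffusers.models.attention import FeedForward

    ff = FeedForward(320)
    print(type(ff.net[2]).__name__)  # LinearWithLoRA on this branch; plain Linear before the change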
31 changes: 1 addition & 30 deletions src/diffusers/models/attention_processor.py
@@ -19,6 +19,7 @@

from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .lora import LoRALinearLayer


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -504,36 +505,6 @@ def __call__(
return hidden_states


class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype

down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)


class LoRAAttnProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism.
111 changes: 111 additions & 0 deletions src/diffusers/models/lora.py
@@ -0,0 +1,111 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from torch import nn


# moved from attention_processor.py
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype

down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)


# copied from LoRALinearLayer above, adapted for 1x1 convolutions
class LoRAConv2dLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype

down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)


class Conv2dWithLoRA(nn.Conv2d):
"""
A convolutional layer that can be used with LoRA.
"""

def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer

def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
else:
return super().forward(x) + self.lora_layer(x)


class LinearWithLoRA(nn.Linear):
"""
A Linear layer that can be used with LoRA.
"""

def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer

def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
else:
return super().forward(x) + self.lora_layer(x)
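
How the pieces compose (editor's sketch, assuming this branch is installed; the dimensions are arbitrary): a fresh `LoRALinearLayer` has a zero-initialized `up` projection, so attaching it to a `LinearWithLoRA` changes nothing until real LoRA weights are loaded into it.

    import torch
    from diffusers.models.lora import LinearWithLoRA, LoRALinearLayer

    base = LinearWithLoRA(64, 128)  # behaves exactly like nn.Linear(64, 128) while lora_layer is None
    x = torch.randn(4, 64)
    y_plain = base(x)

    base.lora_layer = LoRALinearLayer(64, 128, rank=4)
    y_patched = base(x)

    # up.weight is zero-initialized, so the LoRA branch contributes nothing yet.
    assert torch.allclose(y_plain, y_patched)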
5 changes: 3 additions & 2 deletions src/diffusers/models/transformer_2d.py
@@ -23,6 +23,7 @@
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import Conv2dWithLoRA
from .modeling_utils import ModelMixin


@@ -146,7 +147,7 @@ def __init__(
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = Conv2dWithLoRA(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -202,7 +203,7 @@ def __init__(
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = Conv2dWithLoRA(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
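
Why the `proj_in`/`proj_out` swap is safe for existing checkpoints (editor's sketch): with no `lora_layer` attached, `Conv2dWithLoRA` registers no extra parameters, so a pretrained `nn.Conv2d` state dict loads into it unchanged and the forward pass is identical.

    import torch
    from torch import nn
    from diffusers.models.lora import Conv2dWithLoRA

    plain = nn.Conv2d(320, 320, kernel_size=1, stride=1, padding=0)
    wrapped = Conv2dWithLoRA(320, 320, kernel_size=1, stride=1, padding=0)

    wrapped.load_state_dict(plain.state_dict())  # no missing or unexpected keys

    x = torch.randn(1, 320, 16, 16)
    assert torch.allclose(wrapped(x), plain(x))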