Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def text_encoder_attn_modules(text_encoder):

if is_transformers_available():
_import_structure["single_file"].extend(["FromSingleFileMixin"])
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin","ControlLoRAMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]

Expand All @@ -73,7 +73,7 @@ def text_encoder_attn_modules(text_encoder):

if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ControlLoRAMixin
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
else:
Expand Down
133 changes: 119 additions & 14 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def load_lora_weights(
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
controlnet: bool=False,
**kwargs,
):
r"""
Expand Down Expand Up @@ -310,20 +311,21 @@ def lora_state_dict(

network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
if not controlnet:
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

return state_dict, network_alphas

Expand Down Expand Up @@ -1867,3 +1869,106 @@ def _remove_text_encoder_monkey_patch(self):
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)


class ControlLoRAMixin(LoraLoaderMixin):
    r"""
    Load LoRA layers into a ControlNet model.

    ControlNet LoRA checkpoints mix two kinds of parameters: ordinary parameters that were
    initialized from the underlying UNet, and LoRA parameters. Loading therefore happens in
    two stages: the non-LoRA parameters are loaded straight into the model with
    ``load_state_dict(strict=False)``, and the leftover LoRA parameters are wrapped in
    ``LoRAConv2dLayer`` / ``LoRALinearLayer`` modules and attached via ``set_lora_layer``.
    """

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
        """
        Load ControlNet LoRA weights specified in `pretrained_model_name_or_path_or_dict`.

        Args:
            pretrained_model_name_or_path_or_dict (`str` or `dict`):
                Checkpoint identifier or an already-loaded state dict; see
                [`LoraLoaderMixin.lora_state_dict`].
            kwargs:
                Must contain `controlnet_config` (the ControlNet config used to remap
                checkpoint keys). All remaining kwargs are forwarded to `lora_state_dict`.

        Raises:
            ValueError: If `controlnet_config` is missing, if non-LoRA keys fail to load
                into the model, or if a LoRA key does not correspond to a LoRA-compatible
                module.
        """
        from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
        from ..pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint

        # Pop `controlnet_config` *before* forwarding kwargs: otherwise it is passed down
        # to `lora_state_dict` (and on to the checkpoint-loading machinery) as an
        # unexpected keyword argument.
        controlnet_config = kwargs.pop("controlnet_config", None)
        if controlnet_config is None:
            raise ValueError("Must provide a `controlnet_config`.")

        state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs)

        # ControlNet LoRA has a mix of things. Some parameters correspond to LoRA and some correspond
        # to the ones belonging to the original state_dict (initialized from the underlying UNet).
        # So, we first map the LoRA parameters and then we load the remaining state_dict into
        # the ControlNet.
        converted_state_dict = convert_ldm_unet_checkpoint(
            state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True
        )

        # Load whatever is matching. With strict=False, `unexpected_keys` are the entries of
        # `converted_state_dict` that have no counterpart in the model — those must all be
        # LoRA parameters at this point.
        load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False)
        if not all("lora" in k for k in load_state_dict_results.unexpected_keys):
            raise ValueError(
                f"The unexpected keys must only belong to LoRA parameters at this point, but found the following keys that are non-LoRA\n: {load_state_dict_results.unexpected_keys}"
            )

        # Filter out the rest of the state_dict for handling LoRA.
        remaining_state_dict = {
            k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys
        }

        # Group the LoRA parameters by the module they target. A key such as
        # `<module path>.lora.down.weight` is split into the module path and the
        # three-component LoRA sub-key.
        lora_grouped_dict = defaultdict(dict)
        lora_layers_list = []

        for key in list(remaining_state_dict.keys()):
            value = remaining_state_dict.pop(key)
            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
            lora_grouped_dict[attn_processor_key][sub_key] = value

        if len(remaining_state_dict) > 0:
            # Sanity check; report the keys of the dict that is actually non-empty.
            raise ValueError(
                f"The `remaining_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(remaining_state_dict.keys())}"
            )

        for key, value_dict in lora_grouped_dict.items():
            # Walk the dotted module path down from `self` to the target module.
            attn_processor = self
            for sub_key in key.split("."):
                attn_processor = getattr(attn_processor, sub_key)

            # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
            # or add_{k,v,q,out_proj}_proj_lora layers.
            rank = value_dict["lora.down.weight"].shape[0]

            if isinstance(attn_processor, LoRACompatibleConv):
                lora = LoRAConv2dLayer(
                    in_features=attn_processor.in_channels,
                    out_features=attn_processor.out_channels,
                    rank=rank,
                    kernel_size=attn_processor.kernel_size,
                    stride=attn_processor.stride,
                    padding=attn_processor.padding,
                )
            elif isinstance(attn_processor, LoRACompatibleLinear):
                lora = LoRALinearLayer(
                    attn_processor.in_features,
                    attn_processor.out_features,
                    rank,
                )
            else:
                raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

            # Strip the `lora.` prefix so the sub-keys match the LoRA layer's parameters.
            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
            load_state_dict_results = lora.load_state_dict(value_dict, strict=False)
            if not all("initial" in k for k in load_state_dict_results.unexpected_keys):
                raise ValueError("Incorrect `value_dict` for the LoRA layer.")
            lora_layers_list.append((attn_processor, lora))

        # Move the LoRA layers to the model's device/dtype, then attach them.
        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
        for target_module, lora_layer in lora_layers_list:
            target_module.set_lora_layer(lora_layer)

    def unload_lora_weights(self):
        """Detach every previously loaded LoRA layer from the model."""
        for _, module in self.named_modules():
            if hasattr(module, "set_lora_layer"):
                module.set_lora_layer(None)

    # TODO(sayakpaul): Implement `fuse_lora()` and `unfuse_lora()`.
11 changes: 7 additions & 4 deletions src/diffusers/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
from ..models.lora import LoRACompatibleConv
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
Expand Down Expand Up @@ -80,7 +82,7 @@ def __init__(
):
super().__init__()

self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

self.blocks = nn.ModuleList([])

Expand All @@ -106,8 +108,9 @@ def forward(self, conditioning):

return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
):
"""
A ControlNet model.

Expand Down Expand Up @@ -250,7 +253,7 @@ def __init__(
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
self.conv_in = LoRACompatibleConv(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)

Expand Down
7 changes: 4 additions & 3 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from torch import nn

from ..models.lora import LoRACompatibleLinear
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .lora import LoRACompatibleLinear
Expand Down Expand Up @@ -200,10 +201,10 @@ def __init__(
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

self.linear_1 = linear_cls(in_channels, time_embed_dim)
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)

if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None

Expand All @@ -213,7 +214,7 @@ def __init__(
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)

if post_act_fn is None:
self.post_act = None
Expand Down
Loading