Changes from all commits (120 commits)
c7a369a
make controlnet sublcass from a loraloader
sayakpaul Aug 18, 2023
9a78f03
wondering'
sayakpaul Aug 18, 2023
e9fe443
wondering'
sayakpaul Aug 18, 2023
2d4ae00
relax check.
sayakpaul Aug 22, 2023
4932716
exploring
sayakpaul Aug 22, 2023
e736960
sai controlnet
sayakpaul Aug 22, 2023
30dee21
let's see
sayakpaul Aug 22, 2023
6f9e14b
debugging
sayakpaul Aug 22, 2023
2257ba9
debugging
sayakpaul Aug 22, 2023
38fb6fe
debugging
sayakpaul Aug 22, 2023
c8ec943
remove unnecessary statements.
sayakpaul Aug 22, 2023
0709834
simplify condition.
sayakpaul Aug 22, 2023
86515e4
seeing.
sayakpaul Aug 22, 2023
a9dfd86
debugging
sayakpaul Aug 22, 2023
4baa7e3
debugging
sayakpaul Aug 22, 2023
df3dfe3
debugging
sayakpaul Aug 22, 2023
dde7ed6
debugging
sayakpaul Aug 22, 2023
04f663d
debugging
sayakpaul Aug 22, 2023
e47b47d
debugging
sayakpaul Aug 22, 2023
54d1508
successful LoRA state dict parsing.
sayakpaul Aug 22, 2023
6adc8d5
successful LoRA state dict parsing.
sayakpaul Aug 22, 2023
24a2551
debugging
sayakpaul Aug 22, 2023
09003fb
debugging
sayakpaul Aug 22, 2023
8d19bef
debugging
sayakpaul Aug 22, 2023
260d5cc
debugging
sayakpaul Aug 22, 2023
3ad63ea
debugging
sayakpaul Aug 22, 2023
5860478
debugging
sayakpaul Aug 22, 2023
e572736
debugging
sayakpaul Aug 22, 2023
c3e0dd8
debugging
sayakpaul Aug 22, 2023
3924166
debugging
sayakpaul Aug 22, 2023
00fea8a
debugging
sayakpaul Aug 22, 2023
12d7b5d
debugging
sayakpaul Aug 22, 2023
a58abee
debugging
sayakpaul Aug 22, 2023
6295db5
debugging
sayakpaul Aug 22, 2023
ae1a178
debugging
sayakpaul Aug 22, 2023
58c9f98
debugging
sayakpaul Aug 22, 2023
e047c4e
better state dict munging
sayakpaul Aug 22, 2023
4436870
remove print
sayakpaul Aug 22, 2023
50f3f4a
make method a part of it now
sayakpaul Aug 22, 2023
48257fb
fix
sayakpaul Aug 22, 2023
40480de
more stuff
sayakpaul Aug 24, 2023
13dffc3
debugging
sayakpaul Sep 5, 2023
6b6195f
debugging
sayakpaul Sep 5, 2023
7e87bf9
changes
sayakpaul Sep 5, 2023
4c93de5
changes
sayakpaul Sep 5, 2023
182e455
changes
sayakpaul Sep 5, 2023
c13e824
changes
sayakpaul Sep 5, 2023
dc27a08
changes
sayakpaul Sep 5, 2023
e2e5477
changes
sayakpaul Sep 5, 2023
efec092
changes
sayakpaul Sep 5, 2023
e871eee
changes
sayakpaul Sep 5, 2023
9d43c95
changes
sayakpaul Sep 5, 2023
7c26e90
changes
sayakpaul Sep 5, 2023
f9eb243
changes
sayakpaul Sep 5, 2023
000f74c
changes
sayakpaul Sep 5, 2023
101ceeb
changes
sayakpaul Sep 5, 2023
d326f24
changes
sayakpaul Sep 5, 2023
c35161d
changes
sayakpaul Sep 5, 2023
e103f77
changes
sayakpaul Sep 5, 2023
0e42a2c
changes
sayakpaul Sep 5, 2023
5bdb7bb
changes
sayakpaul Sep 5, 2023
e143979
changes
sayakpaul Sep 5, 2023
2baae10
remove unnecessary stuff from loaders.py
sayakpaul Sep 5, 2023
fbb2d7b
Merge branch 'main' into controlnet-sai
sayakpaul Sep 5, 2023
95f09d8
remove unneeded stuff.
sayakpaul Sep 5, 2023
d88c806
better simplicity.
sayakpaul Sep 5, 2023
260bc75
better modularity
sayakpaul Sep 5, 2023
5e5004d
fix: exception raise/.
sayakpaul Sep 5, 2023
11a85cd
empty lora controlnet key
sayakpaul Sep 5, 2023
d166732
empty lora controlnet key
sayakpaul Sep 5, 2023
b3b7798
debugging
sayakpaul Sep 5, 2023
d0e1cfb
debugging
sayakpaul Sep 5, 2023
11ddd6c
debugging
sayakpaul Sep 5, 2023
8f6608d
debugging
sayakpaul Sep 5, 2023
fa4782f
debugging
sayakpaul Sep 5, 2023
aa4f65f
debugging
sayakpaul Sep 5, 2023
e238f3a
debugging
sayakpaul Sep 5, 2023
8206ef0
debugging
sayakpaul Sep 5, 2023
33cfc2d
debugging
sayakpaul Sep 5, 2023
71f3c91
better state_dict munging
sayakpaul Sep 5, 2023
1bfbefb
better state_dict munging
sayakpaul Sep 5, 2023
8ad9b97
better state_dict munging
sayakpaul Sep 5, 2023
d901a9a
sanity
sayakpaul Sep 5, 2023
610be14
sanity
sayakpaul Sep 5, 2023
2027143
sanity
sayakpaul Sep 5, 2023
f7fde8a
fix: embeddings.
sayakpaul Sep 5, 2023
b35f61f
fix: embeddings.
sayakpaul Sep 5, 2023
ebec211
fix: embeddings.
sayakpaul Sep 5, 2023
367e6c0
remove prints.
sayakpaul Sep 5, 2023
dd0ce66
make style
sayakpaul Sep 5, 2023
f17befc
fix: doc
sayakpaul Sep 18, 2023
a66a468
debugging
sayakpaul Sep 18, 2023
9699382
debugging
sayakpaul Sep 18, 2023
70c0c68
debugging
sayakpaul Sep 18, 2023
432fa6b
debugging
sayakpaul Sep 18, 2023
b1099e8
minor clean up
sayakpaul Sep 18, 2023
87ee372
debugging
sayakpaul Sep 18, 2023
05b7f8b
debugging
sayakpaul Sep 18, 2023
e1286db
debugging
sayakpaul Sep 18, 2023
9cfce5f
debugging
sayakpaul Sep 18, 2023
57d52b4
debugging
sayakpaul Sep 19, 2023
8dcc44b
debugging
sayakpaul Sep 19, 2023
a054d80
better support?
sayakpaul Sep 28, 2023
64284b1
make strict loading false
sayakpaul Sep 28, 2023
13e8c87
better conditioning
sayakpaul Sep 28, 2023
b421694
another
sayakpaul Sep 28, 2023
5ceb0a2
log
sayakpaul Sep 28, 2023
567a2de
log
sayakpaul Sep 28, 2023
c6a0406
remove print
sayakpaul Sep 28, 2023
86f5980
change class name
sayakpaul Sep 28, 2023
4087dbf
step by step debug
sayakpaul Oct 9, 2023
ef430bf
step by step debug
sayakpaul Oct 9, 2023
c4ad76e
have t printed.
sayakpaul Oct 9, 2023
bf7afc2
remove dtype of t from commit trail.
sayakpaul Oct 9, 2023
5871ecc
remove dtype of t from commit trail.
sayakpaul Oct 9, 2023
332cbfd
debug
sayakpaul Oct 9, 2023
26662de
debug
sayakpaul Oct 9, 2023
b08a0a6
debug
sayakpaul Oct 9, 2023
ca6895a
debug
sayakpaul Oct 9, 2023
6dc4d69
debug
sayakpaul Oct 10, 2023
145 changes: 121 additions & 24 deletions src/diffusers/loaders.py
@@ -137,7 +137,6 @@ def _unfuse_lora(self):
self.w_down = None

def forward(self, input):
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
@@ -1008,19 +1007,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]]
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
controlnet=False,
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.

<Tip warning={true}>

We support loading A1111 formatted LoRA checkpoints in a limited capacity.

This function is experimental and might change in the future.

</Tip>
Comment on lines -1016 to -1022
Member Author (sayakpaul):

I think we have graduated from this.


Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
@@ -1032,6 +1024,8 @@ def lora_state_dict(
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

controlnet (`bool`, *optional*, defaults to `False`):
Whether the checkpoint being loaded is a ControlNet LoRA checkpoint.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -1143,20 +1137,21 @@ def lora_state_dict(
state_dict = pretrained_model_name_or_path_or_dict

network_alphas = None
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
if not controlnet:
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

return state_dict, network_alphas
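
Not part of the diff: a minimal usage sketch for the new flag, assuming any `LoraLoaderMixin` subclass and a hypothetical checkpoint id. With `controlnet=True` the Kohya-format detection and conversion above is bypassed, so the raw state dict is returned and `network_alphas` stays `None`.

from diffusers import StableDiffusionControlNetPipeline  # any LoraLoaderMixin subclass works

# "some-org/control-lora-canny" is a hypothetical repo id, used only for illustration.
state_dict, network_alphas = StableDiffusionControlNetPipeline.lora_state_dict(
    "some-org/control-lora-canny", controlnet=True
)
assert network_alphas is None  # the conversion branch that sets alphas was skipped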

@@ -1700,7 +1695,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")

if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
@@ -1835,6 +1829,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
te_state_dict.update(te2_state_dict)

new_state_dict = {**unet_state_dict, **te_state_dict}

return new_state_dict, network_alphas
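
For intuition about the block-name remapping in this helper, here is a rough, illustrative sketch of the string substitutions applied above (the key is made up, and the real conversion does considerably more than this):

kohya_key = "lora_unet_input_blocks_1_0_emb_layers_1.lora_down.weight"  # illustrative key
name = kohya_key.replace("lora_unet_", "").replace("_", ".")
name = name.replace("input.blocks", "down_blocks").replace("middle.block", "mid_block")
print(name)  # down_blocks.1.0.emb.layers.1.lora.down.weight (roughly the diffusers layout)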

def unload_lora_weights(self):
@@ -2517,3 +2512,105 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
controlnet.to(torch_dtype=torch_dtype)

return controlnet


class ControlLoRAMixin(LoraLoaderMixin):
# Handles loading of ControlNet LoRA (Control-LoRA) checkpoints into `ControlNetModel`.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint

# Pop `controlnet_config` first so it is not forwarded to `lora_state_dict` via `**kwargs`.
controlnet_config = kwargs.pop("controlnet_config", None)
if controlnet_config is None:
raise ValueError("Must provide a `controlnet_config`.")

state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs)

# A ControlNet LoRA checkpoint mixes two kinds of parameters: LoRA layers, and parameters
# belonging to the original state_dict (initialized from the underlying UNet). We convert
# everything to the diffusers layout first, load what matches directly into the ControlNet,
# and then wire the leftover LoRA parameters into LoRA layers.
converted_state_dict = convert_ldm_unet_checkpoint(
state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True
)

# Load everything that matches; the LoRA parameters surface as `unexpected_keys`.
load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False)
if not all("lora" in k for k in load_state_dict_results.unexpected_keys):
raise ValueError(
f"The unexpected keys must all belong to LoRA parameters at this point, but found the following non-LoRA keys:\n{load_state_dict_results.unexpected_keys}"
)

# Filter out the rest of the state_dict for handling LoRA.
remaining_state_dict = {
k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys
}

# Handle LoRA.
lora_grouped_dict = defaultdict(dict)
lora_layers_list = []

all_keys = list(remaining_state_dict.keys())
for key in all_keys:
value = remaining_state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
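# e.g. "down_blocks.0.attentions.0.proj_in.lora.down.weight" (illustrative key) splits into
# attn_processor_key="down_blocks.0.attentions.0.proj_in" and sub_key="lora.down.weight"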

if len(remaining_state_dict) > 0:
raise ValueError(
f"`remaining_state_dict` has to be empty at this point, but it still has the following keys:\n\n {', '.join(remaining_state_dict.keys())}"
)

for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)

# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]

if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size

lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
# initial_weight=attn_processor.weight,
# initial_bias=attn_processor.bias,
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
# initial_weight=attn_processor.weight,
# initial_bias=attn_processor.bias,
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
load_state_dict_results = lora.load_state_dict(value_dict, strict=False)
if not all("initial" in k for k in load_state_dict_results.unexpected_keys):
raise ValueError("Incorrect `value_dict` for the LoRA layer.")
lora_layers_list.append((attn_processor, lora))

# set correct dtype & device
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]

# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)

def unload_lora_weights(self):
for _, module in self.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)

# TODO (sayakpaul): implement `fuse_lora()` and `unfuse_lora()`.
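
Not part of the diff: a hedged end-to-end sketch of how the mixin is meant to be driven. The Control-LoRA repo id is hypothetical, and passing the model's own config as `controlnet_config` is an assumption about the expected format.

import torch
from diffusers import ControlNetModel

# `ControlNetModel` gains `load_lora_weights` through `ControlLoRAMixin`
# per the class-definition change in controlnet.py below.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float32
)
controlnet.load_lora_weights(
    "some-org/control-lora-canny",               # hypothetical Control-LoRA checkpoint
    controlnet_config=dict(controlnet.config),   # assumed config format
)
# ... run the ControlNet in a pipeline as usual; detach the LoRA layers again with:
controlnet.unload_lora_weights()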
16 changes: 12 additions & 4 deletions src/diffusers/models/controlnet.py
@@ -19,7 +19,8 @@
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
from ..models.lora import LoRACompatibleConv
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -80,7 +81,7 @@ def __init__(
):
super().__init__()

self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

self.blocks = nn.ModuleList([])

@@ -96,6 +97,7 @@ def __init__(

def forward(self, conditioning):
embedding = self.conv_in(conditioning)
print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, -1]}")
embedding = F.silu(embedding)

for block in self.blocks:
@@ -107,7 +109,9 @@ def forward(self, conditioning):
return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
):
"""
A ControlNet model.

@@ -247,7 +251,7 @@ def __init__(
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
self.conv_in = LoRACompatibleConv(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)

@@ -719,13 +723,16 @@ def forward(
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)
print(f"t_emb: {t_emb[0, :3]}")

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
print(f"emb: {emb[0, :3]}")

aug_emb = None

if self.class_embedding is not None:
@@ -764,6 +771,7 @@

# 2. pre-process
sample = self.conv_in(sample)
print(f"From ControlNet conv_in: {sample[0, :5, :5, -1]}")

controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
7 changes: 4 additions & 3 deletions src/diffusers/models/embeddings.py
@@ -18,6 +18,7 @@
import torch
from torch import nn

from ..models.lora import LoRACompatibleLinear
from .activations import get_activation


@@ -166,10 +167,10 @@ def __init__(
):
super().__init__()

self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)

if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None

@@ -179,7 +180,7 @@ def __init__(
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)

if post_act_fn is None:
self.post_act = None
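
The `nn.Linear` → `LoRACompatibleLinear` swaps above are behavior-preserving as long as no LoRA layer is attached; a small sketch of that invariant, not part of the diff (shapes arbitrary):

import torch
from diffusers.models.lora import LoRACompatibleLinear

layer = LoRACompatibleLinear(320, 1280)  # drop-in replacement for nn.Linear(320, 1280)
x = torch.randn(2, 320)
out = layer(x)  # no lora_layer set, so this is a plain linear pass
assert out.shape == (2, 1280)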
60 changes: 58 additions & 2 deletions src/diffusers/models/lora.py
@@ -40,7 +40,17 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):


class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
def __init__(
self,
in_features,
out_features,
rank=4,
network_alpha=None,
device=None,
dtype=None,
# initial_weight=None,
# initial_bias=None,
):
super().__init__()

self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@@ -52,6 +62,10 @@ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
self.out_features = out_features
self.in_features = in_features

# # Control-LoRA specific.
# self.initial_weight = initial_weight
# self.initial_bias = initial_bias

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

@@ -66,11 +80,32 @@ def forward(self, hidden_states):
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)
# else:
# initial_weight = self.initial_weight
# if initial_weight.device != hidden_states.device:
# initial_weight = initial_weight.to(hidden_states.device)
# return torch.nn.functional.linear(
# hidden_states.to(dtype),
# initial_weight
# + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1)))
# .reshape(self.initial_weight.shape)
# .type(orig_dtype),
# self.initial_bias,
# )
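
A tiny numeric check of the scaling in `forward` above, not part of the diff: the layer computes `up(down(x)) * network_alpha / rank`, which equals multiplying by the fused matrix `(network_alpha / rank) * up.weight @ down.weight`.

import torch
from diffusers.models.lora import LoRALinearLayer

lora = LoRALinearLayer(in_features=8, out_features=16, rank=4, network_alpha=8)
with torch.no_grad():
    lora.up.weight.normal_()  # up is zero-initialized; randomize for a non-trivial check

x = torch.randn(3, 8)
fused = (lora.network_alpha / lora.rank) * (lora.up.weight @ lora.down.weight)  # (16, 8)
assert torch.allclose(lora(x), x @ fused.T, atol=1e-5)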


class LoRAConv2dLayer(nn.Module):
def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
self,
in_features,
out_features,
rank=4,
kernel_size=(1, 1),
stride=(1, 1),
padding=0,
network_alpha=None,
# initial_weight=None,
# initial_bias=None,
):
super().__init__()

@@ -84,6 +119,13 @@ def __init__(
self.network_alpha = network_alpha
self.rank = rank

# # Control-LoRA specific.
# self.initial_weight = initial_weight
# self.initial_bias = initial_bias
# self.stride = stride
# self.kernel_size = kernel_size
# self.padding = padding

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

@@ -98,6 +140,20 @@ def forward(self, hidden_states):
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)
# else:
# initial_weight = self.initial_weight
# if initial_weight.device != hidden_states.device:
# initial_weight = initial_weight.to(hidden_states.device)
# return torch.nn.functional.conv2d(
# hidden_states,
# initial_weight
# + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1)))
# .reshape(self.initial_weight.shape)
# .type(orig_dtype),
# self.initial_bias,
# self.stride,
# self.padding,
# )


class LoRACompatibleConv(nn.Conv2d):