diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 40d8804893e2..79274c98659d 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -137,7 +137,6 @@ def _unfuse_lora(self):
         self.w_down = None
 
     def forward(self, input):
-        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0
         if self.lora_linear_layer is None:
@@ -1008,19 +1007,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        controlnet=False,
         **kwargs,
     ):
         r"""
         Return state dict for lora weights and the network alphas.
-
-        <Tip warning={true}>
-
-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
-        This function is experimental and might change in the future.
-
-        </Tip>
-
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 Can be either:
@@ -1032,6 +1024,8 @@ def lora_state_dict(
                     - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
+            controlnet (`bool`, *optional*, defaults to `False`):
+                Whether the checkpoint to load is a ControlNet LoRA checkpoint. If `True`, the Kohya-style conversion below is skipped.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
@@ -1143,20 +1137,21 @@ def lora_state_dict(
             state_dict = pretrained_model_name_or_path_or_dict
 
         network_alphas = None
-        if all(
-            (
-                k.startswith("lora_te_")
-                or k.startswith("lora_unet_")
-                or k.startswith("lora_te1_")
-                or k.startswith("lora_te2_")
-            )
-            for k in state_dict.keys()
-        ):
-            # Map SDXL blocks correctly.
-            if unet_config is not None:
-                # use unet config to remap block numbers
-                state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-            state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
+        if not controlnet:
+            if all(
+                (
+                    k.startswith("lora_te_")
+                    or k.startswith("lora_unet_")
+                    or k.startswith("lora_te1_")
+                    or k.startswith("lora_te2_")
+                )
+                for k in state_dict.keys()
+            ):
+                # Map SDXL blocks correctly.
+                if unet_config is not None:
+                    # use unet config to remap block numbers
+                    state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+                state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
 
         return state_dict, network_alphas
 
@@ -1700,7 +1695,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                 diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
             else:
                 diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-
             if "middle.block" in diffusers_name:
                 diffusers_name = diffusers_name.replace("middle.block", "mid_block")
             else:
@@ -1835,6 +1829,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         te_state_dict.update(te2_state_dict)
 
     new_state_dict = {**unet_state_dict, **te_state_dict}
+
     return new_state_dict, network_alphas
 
     def unload_lora_weights(self):
@@ -2517,3 +2512,110 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             controlnet.to(torch_dtype=torch_dtype)
 
         return controlnet
+
+
+class ControlLoRAMixin(LoraLoaderMixin):
+    # TODO: simplify ControlNet LoRA loading.
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
+        from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+        from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint
+
+        # Pop `controlnet_config` before forwarding the remaining kwargs to `lora_state_dict`.
+        controlnet_config = kwargs.pop("controlnet_config", None)
+        if controlnet_config is None:
+            raise ValueError("Must provide a `controlnet_config` to load a ControlNet LoRA checkpoint.")
+
+        state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs)
+
+        # A ControlNet LoRA checkpoint is a mix: some parameters are LoRA factors, while the rest belong
+        # to the original state_dict (initialized from the underlying UNet). So, we first map the LoRA
+        # parameters and then load the remaining state_dict into the ControlNet.
+        converted_state_dict = convert_ldm_unet_checkpoint(
+            state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True
+        )
+
+        # Load whatever matches directly.
+        load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False)
+        if not all("lora" in k for k in load_state_dict_results.unexpected_keys):
+            raise ValueError(
+                "The unexpected keys must only belong to LoRA parameters at this point, but found the following"
+                f" non-LoRA keys:\n{load_state_dict_results.unexpected_keys}"
+            )
+
+        # Filter out the rest of the state_dict for handling LoRA.
+        remaining_state_dict = {
+            k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys
+        }
+
+        # Handle LoRA.
+        lora_grouped_dict = defaultdict(dict)
+        lora_layers_list = []
+
+        all_keys = list(remaining_state_dict.keys())
+        for key in all_keys:
+            value = remaining_state_dict.pop(key)
+            # Everything up to the last three dot-separated components addresses the target module;
+            # the tail (e.g., `lora.down.weight`) addresses the tensor inside the LoRA layer.
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        if len(remaining_state_dict) > 0:
+            raise ValueError(
+                "The `remaining_state_dict` has to be empty at this point but has the following keys:\n\n"
+                f"{', '.join(remaining_state_dict.keys())}"
+            )
+
+        for key, value_dict in lora_grouped_dict.items():
+            attn_processor = self
+            for sub_key in key.split("."):
+                attn_processor = getattr(attn_processor, sub_key)
+
+            # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+            # or add_{k,v,q,out_proj}_proj_lora layers.
+            rank = value_dict["lora.down.weight"].shape[0]
+
+            if isinstance(attn_processor, LoRACompatibleConv):
+                in_features = attn_processor.in_channels
+                out_features = attn_processor.out_channels
+                kernel_size = attn_processor.kernel_size
+
+                lora = LoRAConv2dLayer(
+                    in_features=in_features,
+                    out_features=out_features,
+                    rank=rank,
+                    kernel_size=kernel_size,
+                    stride=attn_processor.stride,
+                    padding=attn_processor.padding,
+                    # initial_weight=attn_processor.weight,
+                    # initial_bias=attn_processor.bias,
+                )
+            elif isinstance(attn_processor, LoRACompatibleLinear):
+                lora = LoRALinearLayer(
+                    attn_processor.in_features,
+                    attn_processor.out_features,
+                    rank,
+                    # initial_weight=attn_processor.weight,
+                    # initial_bias=attn_processor.bias,
+                )
+            else:
+                raise ValueError(f"Module {key} is not a `LoRACompatibleConv` or `LoRACompatibleLinear` module.")
+
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+            load_state_dict_results = lora.load_state_dict(value_dict, strict=False)
+            if not all("initial" in k for k in load_state_dict_results.unexpected_keys):
+                raise ValueError("Incorrect `value_dict` for the LoRA layer.")
+            lora_layers_list.append((attn_processor, lora))
+
+        # Set correct dtype & device.
+        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
+
+        # Set LoRA layers.
+        for target_module, lora_layer in lora_layers_list:
+            target_module.set_lora_layer(lora_layer)
+
+    def unload_lora_weights(self):
+        for _, module in self.named_modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
+    # TODO (sayakpaul): implement `fuse_lora()` and `unfuse_lora()`.
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index db05b0689cff..4c8c1d93ab20 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -19,7 +19,8 @@
 from torch.nn import functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalControlnetMixin
+from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
+from ..models.lora import LoRACompatibleConv
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -80,7 +81,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+        self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
 
         self.blocks = nn.ModuleList([])
 
@@ -107,7 +108,9 @@ def forward(self, conditioning):
         return embedding
 
 
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+class ControlNetModel(
+    ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
+):
     """
     A ControlNet model.
@@ -247,7 +250,7 @@ def __init__(
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
-        self.conv_in = nn.Conv2d(
+        self.conv_in = LoRACompatibleConv(
             in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
         )
 
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 3bdd758117cd..d6a919ef5d34 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -18,6 +18,7 @@
 import torch
 from torch import nn
 
+from ..models.lora import LoRACompatibleLinear
 from .activations import get_activation
 
@@ -166,10 +167,10 @@ def __init__(
     ):
         super().__init__()
 
-        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
 
         if cond_proj_dim is not None:
-            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+            self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
         else:
             self.cond_proj = None
 
@@ -179,7 +180,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
 
         if post_act_fn is None:
             self.post_act = None
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 834a7051b06d..791a19fb1f10 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -40,7 +40,17 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
 
 
 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        rank=4,
+        network_alpha=None,
+        device=None,
+        dtype=None,
+        # initial_weight=None,
+        # initial_bias=None,
+    ):
         super().__init__()
 
         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@@ -52,6 +62,10 @@ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device
         self.out_features = out_features
         self.in_features = in_features
 
+        # # Control-LoRA specific.
+        # self.initial_weight = initial_weight
+        # self.initial_bias = initial_bias
+
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
@@ -66,11 +80,32 @@ def forward(self, hidden_states):
             up_hidden_states *= self.network_alpha / self.rank
 
         return up_hidden_states.to(orig_dtype)
+        # else:
+        #     initial_weight = self.initial_weight
+        #     if initial_weight.device != hidden_states.device:
+        #         initial_weight = initial_weight.to(hidden_states.device)
+        #     return torch.nn.functional.linear(
+        #         hidden_states.to(dtype),
+        #         initial_weight
+        #         + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1)))
+        #         .reshape(self.initial_weight.shape)
+        #         .type(orig_dtype),
+        #         self.initial_bias,
+        #     )
 
 
 class LoRAConv2dLayer(nn.Module):
     def __init__(
-        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+        self,
+        in_features,
+        out_features,
+        rank=4,
+        kernel_size=(1, 1),
+        stride=(1, 1),
+        padding=0,
+        network_alpha=None,
+        # initial_weight=None,
+        # initial_bias=None,
     ):
         super().__init__()
 
@@ -84,6 +119,13 @@ def __init__(
         self.network_alpha = network_alpha
         self.rank = rank
 
+        # # Control-LoRA specific.
+        # self.initial_weight = initial_weight
+        # self.initial_bias = initial_bias
+        # self.stride = stride
+        # self.kernel_size = kernel_size
+        # self.padding = padding
+
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
@@ -98,6 +140,20 @@ def forward(self, hidden_states):
             up_hidden_states *= self.network_alpha / self.rank
 
         return up_hidden_states.to(orig_dtype)
+        # else:
+        #     initial_weight = self.initial_weight
+        #     if initial_weight.device != hidden_states.device:
+        #         initial_weight = initial_weight.to(hidden_states.device)
+        #     return torch.nn.functional.conv2d(
+        #         hidden_states,
+        #         initial_weight
+        #         + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1)))
+        #         .reshape(self.initial_weight.shape)
+        #         .type(orig_dtype),
+        #         self.initial_bias,
+        #         self.stride,
+        #         self.padding,
+        #     )
 
 
 class LoRACompatibleConv(nn.Conv2d):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 391d58134627..627df5b86b62 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -104,7 +104,10 @@
 class StableDiffusionXLControlNetPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    LoraLoaderMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 22b424bf16f5..9bfbab0c7d7b 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -377,11 +377,21 @@ def create_ldm_bert_config(original_config):
 
 
 def convert_ldm_unet_checkpoint(
-    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+    checkpoint,
+    config,
+    path=None,
+    extract_ema=False,
+    controlnet=False,
+    skip_extract_state_dict=False,
+    controlnet_lora=False,
 ):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
+    if not controlnet and controlnet_lora:
+        raise ValueError("`controlnet_lora=True` requires `controlnet=True`.")
+    if controlnet_lora:
+        skip_extract_state_dict = True
 
     if skip_extract_state_dict:
         unet_state_dict = checkpoint
@@ -419,10 +429,22 @@ def convert_ldm_unet_checkpoint(
 
     new_checkpoint = {}
 
-    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
-    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
-    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
-    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+    if controlnet_lora:
+        # Safe to pop: `lora_controlnet` is just a marker key and carries no weights.
+        _ = unet_state_dict.pop("lora_controlnet")
+
+    if not controlnet_lora:
+        new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+        new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+        new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+        new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+    else:
+        new_checkpoint["time_embedding.linear_1.lora_down.weight"] = unet_state_dict["time_embed.0.down"]
+        new_checkpoint["time_embedding.linear_1.lora_up.weight"] = unet_state_dict["time_embed.0.up"]
+        new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+        new_checkpoint["time_embedding.linear_2.lora_down.weight"] = unet_state_dict["time_embed.2.down"]
+        new_checkpoint["time_embedding.linear_2.lora_up.weight"] = unet_state_dict["time_embed.2.up"]
+        new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
 
     if config["class_embed_type"] is None:
         # No parameters to port
@@ -436,13 +458,26 @@ def convert_ldm_unet_checkpoint(
         raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
 
     if config["addition_embed_type"] == "text_time":
-        new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
-        new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
-        new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
-        new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
-
-    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
-    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+        if not controlnet_lora:
+            new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+            new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+            new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+            new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+        else:
+            new_checkpoint["add_embedding.linear_1.lora_down.weight"] = unet_state_dict["label_emb.0.0.down"]
+            new_checkpoint["add_embedding.linear_1.lora_up.weight"] = unet_state_dict["label_emb.0.0.up"]
+            new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+            new_checkpoint["add_embedding.linear_2.lora_down.weight"] = unet_state_dict["label_emb.0.2.down"]
+            new_checkpoint["add_embedding.linear_2.lora_up.weight"] = unet_state_dict["label_emb.0.2.up"]
+            new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
+    if not controlnet_lora:
+        new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+        new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+    else:
+        new_checkpoint["conv_in.lora_down.weight"] = unet_state_dict["input_blocks.0.0.down"]
+        new_checkpoint["conv_in.lora_up.weight"] = unet_state_dict["input_blocks.0.0.up"]
+        new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
 
     if not controlnet:
         new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
@@ -588,8 +623,9 @@ def convert_ldm_unet_checkpoint(
             orig_index += 2
 
         diffusers_index = 0
+        diffusers_index_limit = 6
 
-        while diffusers_index < 6:
+        while diffusers_index < diffusers_index_limit:
             new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
@@ -599,12 +635,13 @@ def convert_ldm_unet_checkpoint(
             diffusers_index += 1
             orig_index += 2
 
-        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
-            f"input_hint_block.{orig_index}.weight"
-        )
-        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
-            f"input_hint_block.{orig_index}.bias"
-        )
+        if not controlnet_lora:
+            new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.weight"
+            )
+            new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.bias"
+            )
 
     # down blocks
     for i in range(num_input_blocks):
@@ -615,6 +652,21 @@ def convert_ldm_unet_checkpoint(
 
     new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
     new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
 
+    if controlnet_lora:
+        modified_new_checkpoint = {}
+        down_pattern = r"\.down$"
+        up_pattern = r"\.up$"
+
+        for key in new_checkpoint:
+            new_key = key
+            new_key = re.sub(down_pattern, ".lora.down.weight", new_key)
+            new_key = re.sub(up_pattern, ".lora.up.weight", new_key)
+            new_key = new_key.replace("lora_down", "lora.down")
+            new_key = new_key.replace("lora_up", "lora.up")
+            modified_new_checkpoint[new_key] = new_checkpoint[key]
+
+        new_checkpoint = modified_new_checkpoint
+
     return new_checkpoint
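
A few sketches for reviewers follow; none of them are part of the patch. First, a minimal sketch of how the new loading path is meant to be driven, as I read it. The checkpoint path is a placeholder, and the `from_unet` plus `controlnet_config` plumbing is an assumption based on how `convert_ldm_unet_checkpoint` indexes the config, not a settled API:

```python
import torch
from diffusers import ControlNetModel, UNet2DConditionModel

# A Control-LoRA checkpoint stores LoRA factors plus only a handful of full parameters,
# so the ControlNet itself starts from the underlying UNet's weights.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_unet(unet)

# `controlnet_config` is forwarded to `convert_ldm_unet_checkpoint`, which indexes it
# like a dict (e.g., config["class_embed_type"]); the model's own config works for that.
controlnet.load_lora_weights(
    "path/to/control-lora.safetensors",  # placeholder checkpoint path
    controlnet_config=dict(controlnet.config),
)
```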
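
The commented-out `else` branches in `LoRALinearLayer.forward` and `LoRAConv2dLayer.forward` both compute the standard LoRA merge, W' = W + up @ down, against a stashed `initial_weight`. A self-contained sketch of that reshape trick; flattening the factors to 2-D is what lets the same matmul serve both the linear and the conv case:

```python
import torch

def merged_weight(initial_weight, up, down):
    # W' = W + (up @ down): flatten the LoRA factors to 2-D, multiply,
    # then reshape the delta back to the original weight's shape.
    delta = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1))
    return initial_weight + delta.reshape(initial_weight.shape)

# Linear: a rank-4 update on a (16, 32) weight.
w = torch.randn(16, 32)
assert merged_weight(w, torch.randn(16, 4), torch.randn(4, 32)).shape == w.shape

# Conv2d: a rank-4 update on a (16, 8, 3, 3) kernel; `up` is a 1x1 kernel over rank channels.
w = torch.randn(16, 8, 3, 3)
assert merged_weight(w, torch.randn(16, 4, 1, 1), torch.randn(4, 8, 3, 3)).shape == w.shape
```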
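
The `controlnet_lora` post-processing at the end of `convert_ldm_unet_checkpoint` is pure key renaming. Here it is in isolation, with illustrative keys (the exact names depend on the checkpoint):

```python
import re

def remap_control_lora_key(key):
    # Raw checkpoint keys end in `.down` / `.up`; the time/label/conv_in branches above emit
    # `lora_down` / `lora_up`. Both get normalized to the `.lora.down.weight` / `.lora.up.weight`
    # layout that `ControlLoRAMixin.load_lora_weights` groups on.
    key = re.sub(r"\.down$", ".lora.down.weight", key)
    key = re.sub(r"\.up$", ".lora.up.weight", key)
    key = key.replace("lora_down", "lora.down")
    key = key.replace("lora_up", "lora.up")
    return key

assert remap_control_lora_key("down_blocks.0.resnets.0.conv1.down") == "down_blocks.0.resnets.0.conv1.lora.down.weight"
assert remap_control_lora_key("conv_in.lora_down.weight") == "conv_in.lora.down.weight"
assert remap_control_lora_key("time_embedding.linear_1.bias") == "time_embedding.linear_1.bias"  # non-LoRA keys pass through
```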
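
Finally, the grouping step in `ControlLoRAMixin.load_lora_weights` splits each leftover key into a module path and a tensor name on the last three dot-separated components:

```python
from collections import defaultdict

state = {
    # Placeholder values; in the real path these are torch.Tensors.
    "time_embedding.linear_1.lora.down.weight": "down-tensor",
    "time_embedding.linear_1.lora.up.weight": "up-tensor",
}

lora_grouped_dict = defaultdict(dict)
for key, value in state.items():
    # "time_embedding.linear_1" addresses the module; "lora.down.weight" the tensor inside it.
    module_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    lora_grouped_dict[module_key][sub_key] = value

assert lora_grouped_dict["time_embedding.linear_1"] == {
    "lora.down.weight": "down-tensor",
    "lora.up.weight": "up-tensor",
}
```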