From f174f7531c5e73ab48bece369538e1548ae4012f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Fri, 9 Jun 2023 00:26:01 +0900 Subject: [PATCH 1/3] add test_kohya_loras_scaffold.py --- tests/test_kohya_loras_scaffold.py | 302 +++++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) create mode 100644 tests/test_kohya_loras_scaffold.py diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py new file mode 100644 index 000000000000..0c9501cfd191 --- /dev/null +++ b/tests/test_kohya_loras_scaffold.py @@ -0,0 +1,302 @@ +# +# +# TODO: REMOVE THIS FILE +# This file is intended to be used for initial development of new features. +# +# + +import math + +import numpy as np +import safetensors +import torch +import torch.nn as nn +from PIL import Image + +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline + + +# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 +class LoRAModule(torch.nn.Module): + def __init__(self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + + if isinstance(org_module, nn.Conv2d): + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if isinstance(org_module, nn.Conv2d): + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if alpha is None or alpha == 0: + self.alpha = self.lora_dim + else: + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant. 
+ + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + + def forward(self, x): + scale = self.alpha / self.lora_dim + return self.multiplier * scale * self.lora_up(self.lora_down(x)) + + +class LoRAModuleContainer(torch.nn.Module): + def __init__(self, hooks, state_dict, multiplier): + super().__init__() + self.multiplier = multiplier + + # Create LoRAModule from state_dict information + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + lora_dim = value.size()[0] + lora_name_alpha = key.split(".")[0] + ".alpha" + alpha = None + if lora_name_alpha in state_dict: + alpha = state_dict[lora_name_alpha].item() + hook = hooks[lora_name] + lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) + self.register_module(lora_name, lora_module) + + # Load whole LoRA weights + self.load_state_dict(state_dict) + + # Register LoRAModule to LoRAHook + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + hook = hooks[name] + hook.append_lora(module) + + @property + def alpha(self): + return self.multiplier + + @alpha.setter + def alpha(self, multiplier): + self.multiplier = multiplier + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + module.multiplier = multiplier + + def remove_from_hooks(self, hooks): + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + hook = hooks[name] + hook.remove_lora(module) + del module + + +class LoRAHook(torch.nn.Module): + """ + replaces forward method of the original Linear, + instead of replacing the original Linear module. + """ + + def __init__(self): + super().__init__() + self.lora_modules = [] + + def install(self, orig_module): + assert not hasattr(self, "orig_module") + self.orig_module = orig_module + self.orig_forward = self.orig_module.forward + self.orig_module.forward = self.forward + + def uninstall(self): + assert hasattr(self, "orig_module") + self.orig_module.forward = self.orig_forward + del self.orig_forward + del self.orig_module + + def append_lora(self, lora_module): + self.lora_modules.append(lora_module) + + def remove_lora(self, lora_module): + self.lora_modules.remove(lora_module) + + def forward(self, x): + if len(self.lora_modules) == 0: + return self.orig_forward(x) + lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0) + return self.orig_forward(x) + lora + + +class LoRAHookInjector(object): + def __init__(self): + super().__init__() + self.hooks = {} + self.device = None + self.dtype = None + + def _get_target_modules(self, root_module, prefix, target_replace_modules): + target_modules = [] + for name, module in root_module.named_modules(): + if ( + module.__class__.__name__ in target_replace_modules and "transformer_blocks" not in name + ): # to adapt latest diffusers: + for child_name, child_module in module.named_modules(): + is_linear = isinstance(child_module, nn.Linear) + is_conv2d = isinstance(child_module, nn.Conv2d) + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + target_modules.append((lora_name, child_module)) + return target_modules + + def install_hooks(self, pipe): + """Install LoRAHook to the pipe.""" + assert len(self.hooks) == 0 + text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) + unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) + for name, target_module in text_encoder_targets + unet_targets: + hook = LoRAHook() + hook.install(target_module) + self.hooks[name] = hook + # print(name) + + self.device = pipe.device + self.dtype = pipe.unet.dtype + + def uninstall_hooks(self): + """Uninstall LoRAHook from the pipe.""" + for k, v in self.hooks.items(): + v.uninstall() + self.hooks = {} + + def apply_lora(self, filename, alpha=1.0): + """Load LoRA weights and apply LoRA to the pipe.""" + assert len(self.hooks) != 0 + state_dict = safetensors.torch.load_file(filename) + container = LoRAModuleContainer(self.hooks, state_dict, alpha) + container.to(self.device, self.dtype) + return container + + def remove_lora(self, container): + """Remove the individual LoRA from the pipe.""" + container.remove_from_hooks(self.hooks) + + +def install_lora_hook(pipe: DiffusionPipeline): + """Install LoRAHook to the pipe.""" + assert not hasattr(pipe, "lora_injector") + assert not hasattr(pipe, "apply_lora") + assert not hasattr(pipe, "remove_lora") + injector = LoRAHookInjector() + injector.install_hooks(pipe) + pipe.lora_injector = injector + pipe.apply_lora = injector.apply_lora + pipe.remove_lora = injector.remove_lora + + +def uninstall_lora_hook(pipe: DiffusionPipeline): + """Uninstall LoRAHook from the pipe.""" + pipe.lora_injector.uninstall_hooks() + del pipe.lora_injector + del pipe.apply_lora + del pipe.remove_lora + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +if __name__ == "__main__": + torch.cuda.reset_peak_memory_stats() + + pipe = StableDiffusionPipeline.from_pretrained( + "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None + ).to("cuda") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) + pipe.enable_xformers_memory_efficient_attention() + + prompt = "masterpeace, best quality, highres, 1girl, at dusk" + negative_prompt = ( + "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " + "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2) " + ) + lora_fn = "../stable-diffusion-study/models/lora/light_and_shadow.safetensors" + + # Without Lora + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.manual_seed(0), + ).images + image_grid(images, 1, 4).save("test_orig.png") + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + print(f"Without Lora -> {mem_bytes/(10**6)}MB") + + # Hook version (some restricted apply) + install_lora_hook(pipe) + pipe.apply_lora(lora_fn) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.manual_seed(0), + ).images + image_grid(images, 1, 
4).save("test_lora_hook.png") + uninstall_lora_hook(pipe) + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + print(f"Hook version -> {mem_bytes/(10**6)}MB") + + # Diffusers dev version + pipe.load_lora_weights(lora_fn) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.manual_seed(0), + # cross_attention_kwargs={"scale": 0.5}, # lora scale + ).images + image_grid(images, 1, 4).save("test_lora_dev.png") + + mem_bytes = torch.cuda.max_memory_allocated() + print(f"Diffusers dev version -> {mem_bytes/(10**6)}MB") + + # abs-difference image + image_hook = np.array(Image.open("test_lora_hook.png"), dtype=np.int16) + image_dev = np.array(Image.open("test_lora_dev.png"), dtype=np.int16) + image_diff = Image.fromarray(np.abs(image_hook - image_dev).astype(np.uint8)) + image_diff.save("test_lora_hook_dev_diff.png") From 199195681ee075135de000ec6de5e3776f9df8a4 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Fri, 9 Jun 2023 00:45:21 +0900 Subject: [PATCH 2/3] merge kohya-lora-loader-perfection --- src/diffusers/loaders.py | 84 ++++++++++++++- src/diffusers/models/attention.py | 5 +- src/diffusers/models/attention_processor.py | 31 +----- src/diffusers/models/lora.py | 111 ++++++++++++++++++++ src/diffusers/models/transformer_2d.py | 5 +- 5 files changed, 200 insertions(+), 36 deletions(-) create mode 100644 src/diffusers/models/lora.py diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6d273de5ca9d..b1d3bbb122e6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,6 +33,7 @@ SlicedAttnAddedKVProcessor, XFormersAttnProcessor, ) +from .models.lora import Conv2dWithLoRA, LinearWithLoRA, LoRAConv2dLayer, LoRALinearLayer from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, @@ -415,6 +416,37 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + def _load_lora_aux(self, state_dict, network_alpha=None): + # print("\n".join(sorted(state_dict.keys()))) + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["lora.down.weight"].shape[0] + hidden_size = value_dict["lora.up.weight"].shape[0] + target_modules = [module for name, module in self.named_modules() if name == key] + if len(target_modules) == 0: + logger.warning(f"Could not find module {key} in the model. 
Skipping.") + continue + + target_module = target_modules[0] + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + + lora = None + if isinstance(target_module, Conv2dWithLoRA): + lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha) + elif isinstance(target_module, LinearWithLoRA): + lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha) + else: + raise ValueError(f"Module {key} is not a Conv2dWithLoRA or LinearWithLoRA module.") + lora.load_state_dict(value_dict) + lora.to(device=self.device, dtype=self.dtype) + + # install lora + target_module.lora_layer = lora + class TextualInversionLoaderMixin: r""" @@ -917,7 +949,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) + state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha = self._convert_kohya_lora_to_diffusers( + state_dict + ) + self.unet._load_lora_aux(unet_state_dict_aux, network_alpha=network_alpha) + self._load_lora_aux_for_text_encoder(te_state_dict_aux, network_alpha=network_alpha) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as @@ -1282,6 +1318,8 @@ def save_function(weights, filename): def _convert_kohya_lora_to_diffusers(self, state_dict): unet_state_dict = {} te_state_dict = {} + unet_state_dict_aux = {} + te_state_dict_aux = {} network_alpha = None for key, value in state_dict.items(): @@ -1306,12 +1344,20 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") unet_state_dict[diffusers_name] = value unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "ff" in diffusers_name: + unet_state_dict_aux[diffusers_name] = value + unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict_aux[diffusers_name] = value + unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] elif lora_name.startswith("lora_te_"): diffusers_name = key.replace("lora_te_", "").replace("_", ".") diffusers_name = diffusers_name.replace("text.model", "text_model") @@ -1323,11 +1369,45 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): if "self_attn" in diffusers_name: te_state_dict[diffusers_name] = value te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "mlp" in diffusers_name: + te_state_dict_aux[diffusers_name] = value + te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] unet_state_dict = 
{f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alpha + return new_state_dict, unet_state_dict_aux, te_state_dict_aux, network_alpha + + def _load_lora_aux_for_text_encoder(self, state_dict, network_alpha=None): + # print("\n".join(sorted(state_dict.keys()))) + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["lora.down.weight"].shape[0] + target_modules = [module for name, module in self.text_encoder.named_modules() if name == key] + if len(target_modules) == 0: + logger.warning(f"Could not find module {key} in the model. Skipping.") + continue + + target_module = target_modules[0] + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha) + lora_layer.load_state_dict(value_dict) + lora_layer.to(device=self.text_encoder.device, dtype=self.text_encoder.dtype) + + old_forward = target_module.forward + + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) + + return new_forward + + # Monkey-patch. + target_module.forward = make_new_forward(old_forward, lora_layer) class FromCkptMixin: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8805257ebe9a..8476fe31f675 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,6 +21,7 @@ from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings +from .lora import LinearWithLoRA @maybe_allow_in_graph @@ -222,7 +223,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(nn.Linear(inner_dim, dim_out)) + self.net.append(LinearWithLoRA(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) @@ -266,7 +267,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = LinearWithLoRA(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e0404a83cc9a..83c8b8043db5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -19,6 +19,7 @@ from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -504,36 +505,6 @@ def __call__( return hidden_states -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. - # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - self.network_alpha = network_alpha - self.rank = rank - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank - - return up_hidden_states.to(orig_dtype) - - class LoRAAttnProcessor(nn.Module): r""" Processor for implementing the LoRA attention mechanism. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py new file mode 100644 index 000000000000..2bd43ef8dc17 --- /dev/null +++ b/src/diffusers/models/lora.py @@ -0,0 +1,111 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from torch import nn + + +# moved from attention_processor.py +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        self.network_alpha = network_alpha
+        self.rank = rank
+
+        nn.init.normal_(self.down.weight, std=1 / rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states)
+
+        if self.network_alpha is not None:
+            up_hidden_states *= self.network_alpha / self.rank
+
+        return up_hidden_states.to(orig_dtype)
+
+
+# Copied from LoRALinearLayer above, adapted for Conv2d modules
+class LoRAConv2dLayer(nn.Module):
+    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
+
+        self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
+        self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        self.network_alpha = network_alpha
+        self.rank = rank
+
+        nn.init.normal_(self.down.weight, std=1 / rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states)
+
+        if self.network_alpha is not None:
+            up_hidden_states *= self.network_alpha / self.rank
+
+        return up_hidden_states.to(orig_dtype)
+
+
+class Conv2dWithLoRA(nn.Conv2d):
+    """
+    A convolutional layer that can be used with LoRA.
+    """
+
+    def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lora_layer = lora_layer
+
+    def forward(self, x):
+        if self.lora_layer is None:
+            return super().forward(x)
+        else:
+            return super().forward(x) + self.lora_layer(x)
+
+
+class LinearWithLoRA(nn.Linear):
+    """
+    A Linear layer that can be used with LoRA.
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def forward(self, x): + if self.lora_layer is None: + return super().forward(x) + else: + return super().forward(x) + self.lora_layer(x) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index ec4cb371845f..ecc66c1027c9 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,6 +23,7 @@ from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed +from .lora import Conv2dWithLoRA from .modeling_utils import ModelMixin @@ -146,7 +147,7 @@ def __init__( if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.proj_in = Conv2dWithLoRA(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -202,7 +203,7 @@ def __init__( if use_linear_projection: self.proj_out = nn.Linear(inner_dim, in_channels) else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = Conv2dWithLoRA(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) From b6179814f4bfb0054ed063b86ef79edf2668c87b Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 13 Jun 2023 02:04:52 +0900 Subject: [PATCH 3/3] Revert "add test_kohya_loras_scaffold.py" This reverts commit f174f7531c5e73ab48bece369538e1548ae4012f. --- tests/test_kohya_loras_scaffold.py | 302 ----------------------------- 1 file changed, 302 deletions(-) delete mode 100644 tests/test_kohya_loras_scaffold.py diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py deleted file mode 100644 index 0c9501cfd191..000000000000 --- a/tests/test_kohya_loras_scaffold.py +++ /dev/null @@ -1,302 +0,0 @@ -# -# -# TODO: REMOVE THIS FILE -# This file is intended to be used for initial development of new features. 
-# -# - -import math - -import numpy as np -import safetensors -import torch -import torch.nn as nn -from PIL import Image - -from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline - - -# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 -class LoRAModule(torch.nn.Module): - def __init__(self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - - if isinstance(org_module, nn.Conv2d): - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - self.lora_dim = lora_dim - - if isinstance(org_module, nn.Conv2d): - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if alpha is None or alpha == 0: - self.alpha = self.lora_dim - else: - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant. - - # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - - def forward(self, x): - scale = self.alpha / self.lora_dim - return self.multiplier * scale * self.lora_up(self.lora_down(x)) - - -class LoRAModuleContainer(torch.nn.Module): - def __init__(self, hooks, state_dict, multiplier): - super().__init__() - self.multiplier = multiplier - - # Create LoRAModule from state_dict information - for key, value in state_dict.items(): - if "lora_down" in key: - lora_name = key.split(".")[0] - lora_dim = value.size()[0] - lora_name_alpha = key.split(".")[0] + ".alpha" - alpha = None - if lora_name_alpha in state_dict: - alpha = state_dict[lora_name_alpha].item() - hook = hooks[lora_name] - lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) - self.register_module(lora_name, lora_module) - - # Load whole LoRA weights - self.load_state_dict(state_dict) - - # Register LoRAModule to LoRAHook - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - hook = hooks[name] - hook.append_lora(module) - - @property - def alpha(self): - return self.multiplier - - @alpha.setter - def alpha(self, multiplier): - self.multiplier = multiplier - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - module.multiplier = multiplier - - def remove_from_hooks(self, hooks): - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - hook = hooks[name] - hook.remove_lora(module) - del module - - -class LoRAHook(torch.nn.Module): - """ - replaces forward method of the original Linear, - instead of replacing the original Linear module. 
- """ - - def __init__(self): - super().__init__() - self.lora_modules = [] - - def install(self, orig_module): - assert not hasattr(self, "orig_module") - self.orig_module = orig_module - self.orig_forward = self.orig_module.forward - self.orig_module.forward = self.forward - - def uninstall(self): - assert hasattr(self, "orig_module") - self.orig_module.forward = self.orig_forward - del self.orig_forward - del self.orig_module - - def append_lora(self, lora_module): - self.lora_modules.append(lora_module) - - def remove_lora(self, lora_module): - self.lora_modules.remove(lora_module) - - def forward(self, x): - if len(self.lora_modules) == 0: - return self.orig_forward(x) - lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0) - return self.orig_forward(x) + lora - - -class LoRAHookInjector(object): - def __init__(self): - super().__init__() - self.hooks = {} - self.device = None - self.dtype = None - - def _get_target_modules(self, root_module, prefix, target_replace_modules): - target_modules = [] - for name, module in root_module.named_modules(): - if ( - module.__class__.__name__ in target_replace_modules and "transformer_blocks" not in name - ): # to adapt latest diffusers: - for child_name, child_module in module.named_modules(): - is_linear = isinstance(child_module, nn.Linear) - is_conv2d = isinstance(child_module, nn.Conv2d) - if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - target_modules.append((lora_name, child_module)) - return target_modules - - def install_hooks(self, pipe): - """Install LoRAHook to the pipe.""" - assert len(self.hooks) == 0 - text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) - unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) - for name, target_module in text_encoder_targets + unet_targets: - hook = LoRAHook() - hook.install(target_module) - self.hooks[name] = hook - # print(name) - - self.device = pipe.device - self.dtype = pipe.unet.dtype - - def uninstall_hooks(self): - """Uninstall LoRAHook from the pipe.""" - for k, v in self.hooks.items(): - v.uninstall() - self.hooks = {} - - def apply_lora(self, filename, alpha=1.0): - """Load LoRA weights and apply LoRA to the pipe.""" - assert len(self.hooks) != 0 - state_dict = safetensors.torch.load_file(filename) - container = LoRAModuleContainer(self.hooks, state_dict, alpha) - container.to(self.device, self.dtype) - return container - - def remove_lora(self, container): - """Remove the individual LoRA from the pipe.""" - container.remove_from_hooks(self.hooks) - - -def install_lora_hook(pipe: DiffusionPipeline): - """Install LoRAHook to the pipe.""" - assert not hasattr(pipe, "lora_injector") - assert not hasattr(pipe, "apply_lora") - assert not hasattr(pipe, "remove_lora") - injector = LoRAHookInjector() - injector.install_hooks(pipe) - pipe.lora_injector = injector - pipe.apply_lora = injector.apply_lora - pipe.remove_lora = injector.remove_lora - - -def uninstall_lora_hook(pipe: DiffusionPipeline): - """Uninstall LoRAHook from the pipe.""" - pipe.lora_injector.uninstall_hooks() - del pipe.lora_injector - del pipe.apply_lora - del pipe.remove_lora - - -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - grid_w, grid_h = grid.size - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // 
cols * h)) - return grid - - -if __name__ == "__main__": - torch.cuda.reset_peak_memory_stats() - - pipe = StableDiffusionPipeline.from_pretrained( - "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None - ).to("cuda") - pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) - pipe.enable_xformers_memory_efficient_attention() - - prompt = "masterpeace, best quality, highres, 1girl, at dusk" - negative_prompt = ( - "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " - "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2) " - ) - lora_fn = "../stable-diffusion-study/models/lora/light_and_shadow.safetensors" - - # Without Lora - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - ).images - image_grid(images, 1, 4).save("test_orig.png") - - mem_bytes = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - print(f"Without Lora -> {mem_bytes/(10**6)}MB") - - # Hook version (some restricted apply) - install_lora_hook(pipe) - pipe.apply_lora(lora_fn) - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - ).images - image_grid(images, 1, 4).save("test_lora_hook.png") - uninstall_lora_hook(pipe) - - mem_bytes = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - print(f"Hook version -> {mem_bytes/(10**6)}MB") - - # Diffusers dev version - pipe.load_lora_weights(lora_fn) - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - # cross_attention_kwargs={"scale": 0.5}, # lora scale - ).images - image_grid(images, 1, 4).save("test_lora_dev.png") - - mem_bytes = torch.cuda.max_memory_allocated() - print(f"Diffusers dev version -> {mem_bytes/(10**6)}MB") - - # abs-difference image - image_hook = np.array(Image.open("test_lora_hook.png"), dtype=np.int16) - image_dev = np.array(Image.open("test_lora_dev.png"), dtype=np.int16) - image_diff = Image.fromarray(np.abs(image_hook - image_dev).astype(np.uint8)) - image_diff.save("test_lora_hook_dev_diff.png")
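
For reference, the flow that PATCH 2/3 enables is the ordinary load_lora_weights call: when every key in the checkpoint starts with lora_unet_ or lora_te_, _convert_kohya_lora_to_diffusers now also returns the non-attention weights, which are installed through _load_lora_aux (UNet ff, proj_in, proj_out) and _load_lora_aux_for_text_encoder (CLIP MLP). A minimal usage sketch, reusing the checkpoint and LoRA file from the scaffold script above (both are assumed to be available locally):

import torch
from diffusers import StableDiffusionPipeline

# Same checkpoint as in the scaffold script; any SD 1.x pipeline works here.
pipe = StableDiffusionPipeline.from_pretrained(
    "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# Kohya-ss style file: keys are lora_unet_* / lora_te_*, so the new aux path
# is taken in addition to the existing attention-processor path.
pipe.load_lora_weights("../stable-diffusion-study/models/lora/light_and_shadow.safetensors")

image = pipe(
    "masterpiece, best quality, highres, 1girl, at dusk",
    num_inference_steps=15,
    generator=torch.manual_seed(0),
    # cross_attention_kwargs={"scale": 0.5},  # scales only the attention-processor LoRA layers
).images[0]
image.save("test_lora_dev_single.png")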