From 99e35de79180d63cfd85263ceb766617d00d5386 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 13 Jun 2023 14:22:14 -0700 Subject: [PATCH 1/3] refactor to support patching LoRA into T5 instantiate the lora linear layer on the same device as the regular linear layer get lora rank from state dict tests fmt can create lora layer in float32 even when rest of model is float16 fix loading model hook remove load_lora_weights_ and T5 dispatching remove Unet#attn_processors_state_dict docstrings --- examples/dreambooth/train_dreambooth_lora.py | 146 +++--- src/diffusers/loaders.py | 502 +++++++++---------- src/diffusers/models/attention_processor.py | 6 +- src/diffusers/utils/__init__.py | 1 - src/diffusers/utils/constants.py | 1 - tests/models/test_lora_layers.py | 154 ++++-- 6 files changed, 435 insertions(+), 375 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 72fcfa648b48..dbae2483bdca 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -23,6 +23,7 @@ import shutil import warnings from pathlib import Path +from typing import Dict import numpy as np import torch @@ -50,7 +51,10 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.loaders import ( + LoraLoaderMixin, + text_encoder_lora_state_dict, +) from diffusers.models.attention_processor import ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, @@ -60,7 +64,7 @@ SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -647,6 +651,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte return prompt_embeds +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + r""" + Returns: + a state dict containing just the attention processor parameters. + """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter + + return attn_processors_state_dict + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -827,6 +847,7 @@ def main(args): # Set correct lora layers unet_lora_attn_procs = {} + unet_lora_parameters = [] for name, attn_processor in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -844,31 +865,17 @@ def main(args): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - unet_lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_attn_procs[name] = module + unet_lora_parameters.extend(module.parameters()) unet.set_attn_processor(unet_lora_attn_procs) - unet_lora_layers = AttnProcsLayers(unet.attn_processors) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, - # we first load a dummy pipeline with the text encoder and then do the monkey-patching. - text_encoder_lora_layers = None + # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_proj.out_features, cross_attention_dim=None - ) - text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, text_encoder=text_encoder - ) - temp_pipeline._modify_text_encoder(text_lora_attn_procs) - text_encoder = temp_pipeline.text_encoder - del temp_pipeline + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -877,23 +884,13 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None - if args.train_text_encoder: - text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() - unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() - for model in models: - state_dict = model.state_dict() - - if ( - text_encoder_lora_layers is not None - and text_encoder_keys is not None - and state_dict.keys() == text_encoder_keys - ): - # text encoder - text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_keys: - # unet - unet_lora_layers_to_save = state_dict + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -905,27 +902,24 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - # Note we DON'T pass the unet and text encoder here an purpose - # so that the we don't accidentally override the LoRA layers of - # unet_lora_layers and text_encoder_lora_layers which are stored in `models` - # with new torch.nn.Modules / weights. We simply use the pipeline class as - # an easy way to load the lora checkpoints - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - torch_dtype=weight_dtype, - ) - temp_pipeline.load_lora_weights(input_dir) + unet_ = None + text_encoder_ = None - # load lora weights into models - models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) - if len(models) > 1: - models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + while len(models) > 0: + model = models.pop() - # delete temporary pipeline and pop models - del temp_pipeline - for _ in range(len(models)): - models.pop() + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + ) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -955,9 +949,9 @@ def load_model_hook(models, input_dir): # Optimizer creation params_to_optimize = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) optimizer = optimizer_class( params_to_optimize, @@ -1046,12 +1040,12 @@ def compute_text_embeddings(prompt): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) else: - unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -1200,9 +1194,9 @@ def compute_text_embeddings(prompt): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -1291,15 +1285,17 @@ def compute_text_embeddings(prompt): pipeline_args = {"prompt": args.validation_prompt} if args.validation_images is None: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) else: images = [] for image in args.validation_images: image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: @@ -1322,12 +1318,16 @@ def compute_text_embeddings(prompt): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + unet_lora_layers = unet_attn_processors_state_dict(unet) - if text_encoder is not None: + if text_encoder is not None and args.train_text_encoder: + text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = text_encoder.to(torch.float32) - text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + else: + text_encoder_lora_layers = None LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 1bdd33fa80cb..0cfe94c9b2b6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -20,6 +20,7 @@ import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download +from torch import nn from .models.attention_processor import ( AttnAddedKVProcessor, @@ -29,6 +30,7 @@ LoRAAttnAddedKVProcessor, LoRAAttnProcessor, LoRAAttnProcessor2_0, + LoRALinearLayer, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, XFormersAttnProcessor, @@ -36,7 +38,6 @@ from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, - TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, @@ -49,7 +50,7 @@ import safetensors if is_transformers_available(): - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer logger = logging.get_logger(__name__) @@ -67,6 +68,64 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" +class PatchedLoraProjection(nn.Module): + def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): + super().__init__() + self.regular_linear_layer = regular_linear_layer + + device = self.regular_linear_layer.weight.device + + if dtype is None: + dtype = self.regular_linear_layer.weight.dtype + + self.lora_linear_layer = LoRALinearLayer( + self.regular_linear_layer.in_features, + self.regular_linear_layer.out_features, + network_alpha=network_alpha, + device=device, + dtype=dtype, + rank=rank, + ) + + self.lora_scale = lora_scale + + def forward(self, input): + return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) + + +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, CLIPTextModel): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + class AttnProcsLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() @@ -744,9 +803,48 @@ class LoraLoaderMixin: unet_name = UNET_NAME def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + + See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into + `self.unet`. + + See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded + into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + + kwargs: + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + """ + state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + self.load_lora_into_text_encoder( + state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale + ) + + @classmethod + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): r""" - Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + Return state dict for lora weights + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -801,9 +899,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - # set lora scale to a reasonable default - self._lora_scale = 1.0 - if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" @@ -840,7 +935,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: + except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights @@ -866,286 +961,182 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) + state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict) + + return state_dict, network_alpha + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alpha, unet): + """ + This will load the LoRA layers specified in `state_dict` into `unet` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + """ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys): + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to UNet. - unet_keys = [k for k in keys if k.startswith(self.unet_name)] - logger.info(f"Loading {self.unet_name}.") + unet_keys = [k for k in keys if k.startswith(cls.unet_name)] + logger.info(f"Loading {cls.unet_name}.") unet_lora_state_dict = { - k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys + k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys } - self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) - - # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] - text_encoder_lora_state_dict = { - k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {self.text_encoder_name}.") - attn_procs_text_encoder = self._load_text_encoder_attn_procs( - text_encoder_lora_state_dict, network_alpha=network_alpha - ) - self._modify_text_encoder(attn_procs_text_encoder) - - # save lora attn procs of text encoder so that it can be easily retrieved - self._text_encoder_lora_attn_procs = attn_procs_text_encoder + unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) # Otherwise, we're dealing with the old format. This means the `state_dict` should only # contain the module names of the `unet` as its keys WITHOUT any prefix. elif not all( - key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys() ): - self.unet.load_attn_procs(state_dict) + unet.load_attn_procs(state_dict) warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + @classmethod + def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key shoult be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + """ + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)] + text_encoder_lora_state_dict = { + k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {cls.text_encoder_name}.") + + if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. + for name, _ in text_encoder_attn_modules(text_encoder): + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") + + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + + rank = text_encoder_lora_state_dict[ + "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight" + ].shape[1] + + cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank) + + # set correct dtype & device + text_encoder_lora_state_dict = { + k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) + for k, v in text_encoder_lora_state_dict.items() + } + + load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) + if len(load_state_dict_results.unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - @property - def text_encoder_lora_attn_procs(self): - if hasattr(self, "_text_encoder_lora_attn_procs"): - return self._text_encoder_lora_attn_procs - return - - def _remove_text_encoder_monkey_patch(self): - # Loop over the CLIPAttention module of text_encoder - for name, attn_module in self.text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - # Loop over the LoRA layers - for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): - # Retrieve the q/k/v/out projection of CLIPAttention - module = attn_module.get_submodule(text_encoder_attr) - if hasattr(module, "old_forward"): - # restore original `forward` to remove monkey-patch - module.forward = module.old_forward - delattr(module, "old_forward") - - def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): + @classmethod + def _remove_text_encoder_monkey_patch(cls, text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj = attn_module.q_proj.regular_linear_layer + attn_module.k_proj = attn_module.k_proj.regular_linear_layer + attn_module.v_proj = attn_module.v_proj.regular_linear_layer + attn_module.out_proj = attn_module.out_proj.regular_linear_layer + + @classmethod + def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None): r""" Monkey-patches the forward passes of attention modules of the text encoder. - - Parameters: - attn_processors: Dict[str, `LoRAAttnProcessor`]: - A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. """ # First, remove any monkey-patch that might have been applied before - self._remove_text_encoder_monkey_patch() - - # Loop over the CLIPAttention module of text_encoder - for name, attn_module in self.text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - # Loop over the LoRA layers - for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): - # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer. - module = attn_module.get_submodule(text_encoder_attr) - lora_layer = attn_processors[name].get_submodule(attn_proc_attr) - - # save old_forward to module that can be used to remove monkey-patch - old_forward = module.old_forward = module.forward - - # create a new scope that locks in the old_forward, lora_layer value for each new_forward function - # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060 - def make_new_forward(old_forward, lora_layer): - def new_forward(x): - result = old_forward(x) + self.lora_scale * lora_layer(x) - return result - - return new_forward - - # Monkey-patch. - module.forward = make_new_forward(old_forward, lora_layer) - - @property - def _lora_attn_processor_attr_to_text_encoder_attr(self): - return { - "to_q_lora": "q_proj", - "to_k_lora": "k_proj", - "to_v_lora": "v_proj", - "to_out_lora": "out_proj", - } - - def _load_text_encoder_attn_procs( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs - ): - r""" - Load pretrained attention processor layers for - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - - - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - Returns: - `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding - [`LoRAAttnProcessor`]. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - """ + cls._remove_text_encoder_monkey_patch(text_encoder) - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - network_alpha = kwargs.pop("network_alpha", None) + lora_parameters = [] - if use_safetensors and not is_safetensors_available(): - raise ValueError( - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + for _, attn_module in text_encoder_attn_modules(text_encoder): + attn_module.q_proj = PatchedLoraProjection( + attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype ) + lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) - allow_pickle = False - if use_safetensors is None: - use_safetensors = is_safetensors_available() - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - # fill attn processors - attn_processors = {} - - is_lora = all("lora" in k for k in state_dict.keys()) - - if is_lora: - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + attn_module.k_proj = PatchedLoraProjection( + attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - attn_processors[key] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=rank, - network_alpha=network_alpha, - ) - attn_processors[key].load_state_dict(value_dict) + attn_module.v_proj = PatchedLoraProjection( + attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) - else: - raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + attn_module.out_proj = PatchedLoraProjection( + attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) - # set correct dtype & device - attn_processors = { - k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() - } - return attn_processors + return lora_parameters @classmethod def save_lora_weights( @@ -1225,7 +1216,8 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") - def _convert_kohya_lora_to_diffusers(self, state_dict): + @classmethod + def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} network_alpha = None diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0bc7886c2653..9b3a06436c6c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -506,14 +506,14 @@ def __call__( class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): super().__init__() if rank > min(in_features, out_features): raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 36cbe82f79e7..d145392e9eb0 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,7 +30,6 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, - TEXT_ENCODER_ATTN_MODULE, WEIGHTS_NAME, ) from .deprecation_utils import deprecate diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 3c641a259a81..b9e60a2a873b 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,4 +30,3 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] -TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index aaacf1e68f9f..894d82a4fcf6 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,18 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import gc import os import tempfile import unittest +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from huggingface_hub.repocard import RepoCard from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules from diffusers.models.attention_processor import ( Attention, AttnProcessor, @@ -33,7 +34,8 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device +from diffusers.utils.testing_utils import require_torch_gpu, slow def create_unet_lora_layers(unet: nn.Module): @@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=module.out_proj.out_features, cross_attention_dim=None - ) + for name, module in text_encoder_attn_modules(text_encoder): + if isinstance(module.out_proj, nn.Linear): + out_features = module.out_proj.out_features + elif isinstance(module.out_proj, PatchedLoraProjection): + out_features = module.out_proj.regular_linear_layer.out_features + else: + assert False, module.out_proj.__class__ + + text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None) return text_lora_attn_procs @@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): return text_encoder_lora_layers -def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): - for _, attn_proc in text_lora_attn_procs.items(): - # set up.weights - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - weight = ( - torch.randn_like(layer_module.up.weight) - if randn_weight - else torch.zeros_like(layer_module.up.weight) - ) - layer_module.up.weight = torch.nn.Parameter(weight) +def set_lora_weights(text_lora_attn_parameters, randn_weight=False): + with torch.no_grad(): + for parameter in text_lora_attn_parameters: + if randn_weight: + parameter[:] = torch.randn_like(parameter) + else: + torch.zero_(parameter) class LoraLoaderMixinTests(unittest.TestCase): @@ -281,16 +283,10 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - # create lora_attn_procs with zeroed out up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=False) - # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=False) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -301,15 +297,12 @@ def test_text_encoder_lora_monkey_patch(self): ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" # create lora_attn_procs with randn up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=True) + create_text_encoder_lora_attn_procs(pipe.text_encoder) # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -329,16 +322,10 @@ def test_text_encoder_lora_remove_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - # create lora_attn_procs with randn up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=True) - # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -349,7 +336,7 @@ def test_text_encoder_lora_remove_monkey_patch(self): ), "lora outputs should be different to without lora outputs" # remove monkey patch - pipe._remove_text_encoder_monkey_patch() + pipe._remove_text_encoder_monkey_patch(pipe.text_encoder) # inference with removed lora outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0] @@ -467,3 +454,86 @@ def test_lora_save_load_with_xformers(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + +@slow +@require_torch_gpu +class LoraIntegrationTests(unittest.TestCase): + def test_dreambooth_old_format(self): + generator = torch.Generator("cpu").manual_seed(0) + + lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe( + "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_dreambooth_text_encoder_new_format(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/lora-trained" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_a1111(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_vanilla_funetuning(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) From 7f23d518b05f0d7bdda3e18cb43de5961a82d3c9 Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 5 Jul 2023 19:27:49 -0700 Subject: [PATCH 2/3] text encoder monkeypatch class method --- src/diffusers/loaders.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0cfe94c9b2b6..3962b524139c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1095,8 +1095,11 @@ def lora_scale(self) -> float: # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + @classmethod - def _remove_text_encoder_monkey_patch(cls, text_encoder): + def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): attn_module.q_proj = attn_module.q_proj.regular_linear_layer @@ -1111,7 +1114,7 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra """ # First, remove any monkey-patch that might have been applied before - cls._remove_text_encoder_monkey_patch(text_encoder) + cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) lora_parameters = [] From 104ae132261b6baf45795c597195827ca8d15d26 Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 5 Jul 2023 20:26:57 -0700 Subject: [PATCH 3/3] fix test --- tests/models/test_lora_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 894d82a4fcf6..3190a123898c 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -336,7 +336,7 @@ def test_text_encoder_lora_remove_monkey_patch(self): ), "lora outputs should be different to without lora outputs" # remove monkey patch - pipe._remove_text_encoder_monkey_patch(pipe.text_encoder) + pipe._remove_text_encoder_monkey_patch() # inference with removed lora outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]