diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e3222357ae9b..8cc17281d016 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -23,6 +23,7 @@ import shutil import warnings from pathlib import Path +from typing import Dict import numpy as np import torch @@ -50,7 +51,10 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.loaders import ( + LoraLoaderMixin, + text_encoder_lora_state_dict, +) from diffusers.models.attention_processor import ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, @@ -60,7 +64,7 @@ SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte return prompt_embeds +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + r""" + Returns: + a state dict containing just the attention processor parameters. + """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter + + return attn_processors_state_dict + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -833,6 +853,7 @@ def main(args): # Set correct lora layers unet_lora_attn_procs = {} + unet_lora_parameters = [] for name, attn_processor in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): @@ -850,35 +871,18 @@ def main(args): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - unet_lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=args.rank, - ) + + module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_attn_procs[name] = module + unet_lora_parameters.extend(module.parameters()) unet.set_attn_processor(unet_lora_attn_procs) - unet_lora_layers = AttnProcsLayers(unet.attn_processors) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, - # we first load a dummy pipeline with the text encoder and then do the monkey-patching. - text_encoder_lora_layers = None + # So, instead, we monkey-patch the forward calls of its attention-blocks. 
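+    # `LoraLoaderMixin._modify_text_encoder` wraps each q/k/v/out projection of the text encoder's
+    # attention blocks in a `PatchedLoraProjection` and returns the newly created LoRA parameters to optimize.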
if args.train_text_encoder: - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_proj.out_features, - cross_attention_dim=None, - rank=args.rank, - ) - text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, text_encoder=text_encoder - ) - temp_pipeline._modify_text_encoder(text_lora_attn_procs) - text_encoder = temp_pipeline.text_encoder - del temp_pipeline + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None - if args.train_text_encoder: - text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() - unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() - for model in models: - state_dict = model.state_dict() - - if ( - text_encoder_lora_layers is not None - and text_encoder_keys is not None - and state_dict.keys() == text_encoder_keys - ): - # text encoder - text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_keys: - # unet - unet_lora_layers_to_save = state_dict + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - # Note we DON'T pass the unet and text encoder here an purpose - # so that the we don't accidentally override the LoRA layers of - # unet_lora_layers and text_encoder_lora_layers which are stored in `models` - # with new torch.nn.Modules / weights. 
We simply use the pipeline class as - # an easy way to load the lora checkpoints - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - torch_dtype=weight_dtype, - ) - temp_pipeline.load_lora_weights(input_dir) + unet_ = None + text_encoder_ = None - # load lora weights into models - models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) - if len(models) > 1: - models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + while len(models) > 0: + model = models.pop() - # delete temporary pipeline and pop models - del temp_pipeline - for _ in range(len(models)): - models.pop() + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + ) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -965,9 +956,9 @@ def load_model_hook(models, input_dir): # Optimizer creation params_to_optimize = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) optimizer = optimizer_class( params_to_optimize, @@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt): # Prepare everything with our `accelerator`. if args.train_text_encoder: - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) else: - unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_layers, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
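The load hook above uses the new classmethod-based API, which also works on standalone models outside of a pipeline. A minimal sketch of that flow, assuming a LoRA checkpoint previously written by the save hook; the `base_model` and `checkpoint_dir` values below are placeholders:

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from transformers import CLIPTextModel

base_model = "runwayml/stable-diffusion-v1-5"   # placeholder base model
checkpoint_dir = "output_dir/checkpoint-500"    # placeholder; written by accelerator.save_state(...)

unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")

# returns the (optionally kohya-converted) state dict together with its network alpha
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(checkpoint_dir)

# keys prefixed with `unet.` / `text_encoder.` are routed to the corresponding model
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet)
LoraLoaderMixin.load_lora_into_text_encoder(
    lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder
)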
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet_lora_layers.parameters() + else unet_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt): pipeline_args = {"prompt": args.validation_prompt} if args.validation_images is None: - images = [ - pipeline(**pipeline_args, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) else: images = [] for image in args.validation_images: image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: @@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + unet_lora_layers = unet_attn_processors_state_dict(unet) - if text_encoder is not None: + if text_encoder is not None and args.train_text_encoder: + text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = text_encoder.to(torch.float32) - text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + else: + text_encoder_lora_layers = None LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 525bb446b77e..89ce88455e20 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -20,6 +20,7 @@ import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download +from torch import nn from .models.attention_processor import ( AttnAddedKVProcessor, @@ -29,6 +30,7 @@ LoRAAttnAddedKVProcessor, LoRAAttnProcessor, LoRAAttnProcessor2_0, + LoRALinearLayer, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, XFormersAttnProcessor, @@ -36,7 +38,6 @@ from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, - TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, @@ -49,7 +50,7 @@ import safetensors if is_transformers_available(): - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer logger = logging.get_logger(__name__) @@ -67,6 +68,64 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" +class PatchedLoraProjection(nn.Module): + def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): + super().__init__() + self.regular_linear_layer = regular_linear_layer + + device = self.regular_linear_layer.weight.device + + if dtype is None: + dtype = self.regular_linear_layer.weight.dtype + + self.lora_linear_layer = LoRALinearLayer( + self.regular_linear_layer.in_features, + self.regular_linear_layer.out_features, + 
network_alpha=network_alpha, + device=device, + dtype=dtype, + rank=rank, + ) + + self.lora_scale = lora_scale + + def forward(self, input): + return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) + + +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, CLIPTextModel): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + class AttnProcsLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() @@ -744,9 +803,48 @@ class LoraLoaderMixin: unet_name = UNET_NAME def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + + See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into + `self.unet`. + + See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded + into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + + kwargs: + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + """ + state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + self.load_lora_into_text_encoder( + state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale + ) + + @classmethod + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): r""" - Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + Return state dict for lora weights + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. 
+ + Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -801,9 +899,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - # set lora scale to a reasonable default - self._lora_scale = 1.0 - if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" @@ -840,7 +935,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: + except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights @@ -866,286 +961,185 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) + state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict) + + return state_dict, network_alpha + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alpha, unet): + """ + This will load the LoRA layers specified in `state_dict` into `unet` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + """ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys): + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to UNet. - unet_keys = [k for k in keys if k.startswith(self.unet_name)] - logger.info(f"Loading {self.unet_name}.") + unet_keys = [k for k in keys if k.startswith(cls.unet_name)] + logger.info(f"Loading {cls.unet_name}.") unet_lora_state_dict = { - k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys + k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys } - self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) - - # Load the layers corresponding to text encoder and make necessary adjustments. 
- text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] - text_encoder_lora_state_dict = { - k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {self.text_encoder_name}.") - attn_procs_text_encoder = self._load_text_encoder_attn_procs( - text_encoder_lora_state_dict, network_alpha=network_alpha - ) - self._modify_text_encoder(attn_procs_text_encoder) - - # save lora attn procs of text encoder so that it can be easily retrieved - self._text_encoder_lora_attn_procs = attn_procs_text_encoder + unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) # Otherwise, we're dealing with the old format. This means the `state_dict` should only # contain the module names of the `unet` as its keys WITHOUT any prefix. elif not all( - key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys() ): - self.unet.load_attn_procs(state_dict) + unet.load_attn_procs(state_dict) warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + @classmethod + def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key shoult be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + """ + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)] + text_encoder_lora_state_dict = { + k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {cls.text_encoder_name}.") + + if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. 
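+                    # e.g. `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight` is remapped
+                    # below to `text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.up.weight`.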
+ for name, _ in text_encoder_attn_modules(text_encoder): + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") + + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + + rank = text_encoder_lora_state_dict[ + "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight" + ].shape[1] + + cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank) + + # set correct dtype & device + text_encoder_lora_state_dict = { + k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) + for k, v in text_encoder_lora_state_dict.items() + } + + load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) + if len(load_state_dict_results.unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. 
# if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - @property - def text_encoder_lora_attn_procs(self): - if hasattr(self, "_text_encoder_lora_attn_procs"): - return self._text_encoder_lora_attn_procs - return - def _remove_text_encoder_monkey_patch(self): - # Loop over the CLIPAttention module of text_encoder - for name, attn_module in self.text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - # Loop over the LoRA layers - for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): - # Retrieve the q/k/v/out projection of CLIPAttention - module = attn_module.get_submodule(text_encoder_attr) - if hasattr(module, "old_forward"): - # restore original `forward` to remove monkey-patch - module.forward = module.old_forward - delattr(module, "old_forward") - - def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + + @classmethod + def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj = attn_module.q_proj.regular_linear_layer + attn_module.k_proj = attn_module.k_proj.regular_linear_layer + attn_module.v_proj = attn_module.v_proj.regular_linear_layer + attn_module.out_proj = attn_module.out_proj.regular_linear_layer + + @classmethod + def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None): r""" Monkey-patches the forward passes of attention modules of the text encoder. - - Parameters: - attn_processors: Dict[str, `LoRAAttnProcessor`]: - A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. """ # First, remove any monkey-patch that might have been applied before - self._remove_text_encoder_monkey_patch() - - # Loop over the CLIPAttention module of text_encoder - for name, attn_module in self.text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - # Loop over the LoRA layers - for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): - # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer. - module = attn_module.get_submodule(text_encoder_attr) - lora_layer = attn_processors[name].get_submodule(attn_proc_attr) - - # save old_forward to module that can be used to remove monkey-patch - old_forward = module.old_forward = module.forward - - # create a new scope that locks in the old_forward, lora_layer value for each new_forward function - # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060 - def make_new_forward(old_forward, lora_layer): - def new_forward(x): - result = old_forward(x) + self.lora_scale * lora_layer(x) - return result - - return new_forward - - # Monkey-patch. - module.forward = make_new_forward(old_forward, lora_layer) - - @property - def _lora_attn_processor_attr_to_text_encoder_attr(self): - return { - "to_q_lora": "q_proj", - "to_k_lora": "k_proj", - "to_v_lora": "v_proj", - "to_out_lora": "out_proj", - } - - def _load_text_encoder_attn_procs( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs - ): - r""" - Load pretrained attention processor layers for - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). 
- - - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - Returns: - `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding - [`LoRAAttnProcessor`]. - - + cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). 
- - - """ + lora_parameters = [] - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - network_alpha = kwargs.pop("network_alpha", None) - - if use_safetensors and not is_safetensors_available(): - raise ValueError( - "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + for _, attn_module in text_encoder_attn_modules(text_encoder): + attn_module.q_proj = PatchedLoraProjection( + attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype ) + lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) - allow_pickle = False - if use_safetensors is None: - use_safetensors = is_safetensors_available() - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except IOError as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - # fill attn processors - attn_processors = {} - - is_lora = all("lora" in k for k in state_dict.keys()) - - if is_lora: - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + attn_module.k_proj = PatchedLoraProjection( + attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - attn_processors[key] = 
attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=rank, - network_alpha=network_alpha, - ) - attn_processors[key].load_state_dict(value_dict) + attn_module.v_proj = PatchedLoraProjection( + attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) - else: - raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + attn_module.out_proj = PatchedLoraProjection( + attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) - # set correct dtype & device - attn_processors = { - k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() - } - return attn_processors + return lora_parameters @classmethod def save_lora_weights( @@ -1225,7 +1219,8 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") - def _convert_kohya_lora_to_diffusers(self, state_dict): + @classmethod + def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} network_alpha = None diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5b6a161f8466..da2920fa671a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -506,14 +506,14 @@ def __call__( class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): super().__init__() if rank > min(in_features, out_features): raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7449df99ba80..98fac64497e7 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,7 +30,6 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, - TEXT_ENCODER_ATTN_MODULE, WEIGHTS_NAME, ) from .deprecation_utils import deprecate diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 3c641a259a81..b9e60a2a873b 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,4 +30,3 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] -TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index aaacf1e68f9f..3190a123898c 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,18 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import gc import os import tempfile import unittest +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from huggingface_hub.repocard import RepoCard from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules from diffusers.models.attention_processor import ( Attention, AttnProcessor, @@ -33,7 +34,8 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) -from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device +from diffusers.utils.testing_utils import require_torch_gpu, slow def create_unet_lora_layers(unet: nn.Module): @@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - for name, module in text_encoder.named_modules(): - if name.endswith(TEXT_ENCODER_ATTN_MODULE): - text_lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=module.out_proj.out_features, cross_attention_dim=None - ) + for name, module in text_encoder_attn_modules(text_encoder): + if isinstance(module.out_proj, nn.Linear): + out_features = module.out_proj.out_features + elif isinstance(module.out_proj, PatchedLoraProjection): + out_features = module.out_proj.regular_linear_layer.out_features + else: + assert False, module.out_proj.__class__ + + text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None) return text_lora_attn_procs @@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): return text_encoder_lora_layers -def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): - for _, attn_proc in text_lora_attn_procs.items(): - # set up.weights - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - weight 
= ( - torch.randn_like(layer_module.up.weight) - if randn_weight - else torch.zeros_like(layer_module.up.weight) - ) - layer_module.up.weight = torch.nn.Parameter(weight) +def set_lora_weights(text_lora_attn_parameters, randn_weight=False): + with torch.no_grad(): + for parameter in text_lora_attn_parameters: + if randn_weight: + parameter[:] = torch.randn_like(parameter) + else: + torch.zero_(parameter) class LoraLoaderMixinTests(unittest.TestCase): @@ -281,16 +283,10 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - # create lora_attn_procs with zeroed out up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=False) - # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=False) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -301,15 +297,12 @@ def test_text_encoder_lora_monkey_patch(self): ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" # create lora_attn_procs with randn up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=True) + create_text_encoder_lora_attn_procs(pipe.text_encoder) # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -329,16 +322,10 @@ def test_text_encoder_lora_remove_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - # create lora_attn_procs with randn up.weights - text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) - set_lora_up_weights(text_attn_procs, randn_weight=True) - # monkey patch - pipe._modify_text_encoder(text_attn_procs) + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. - del text_attn_procs - gc.collect() + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -467,3 +454,86 @@ def test_lora_save_load_with_xformers(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + +@slow +@require_torch_gpu +class LoraIntegrationTests(unittest.TestCase): + def test_dreambooth_old_format(self): + generator = torch.Generator("cpu").manual_seed(0) + + lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe( + "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_dreambooth_text_encoder_new_format(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/lora-trained" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_a1111(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_vanilla_funetuning(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4))
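As a usage note, the weights saved by `LoraLoaderMixin.save_lora_weights` at the end of the training script can be loaded back into a pipeline the same way the integration tests above do. A short sketch with a placeholder model id and output directory:

import torch
from diffusers import StableDiffusionPipeline

base_model_id = "runwayml/stable-diffusion-v1-5"  # placeholder; use the base model that was fine-tuned
lora_output_dir = "path/to/output_dir"            # placeholder; contains pytorch_lora_weights.bin

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to("cuda")

# loads the `unet.*` entries into pipe.unet and the `text_encoder.*` entries into pipe.text_encoder
pipe.load_lora_weights(lora_output_dir)

generator = torch.Generator("cpu").manual_seed(0)
image = pipe("A photo of a sks dog", generator=generator, num_inference_steps=25).images[0]
image.save("lora_sample.png")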