From 65a5c452fcce914703aa641060b9cc9c26dc3117 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 21:53:35 +0200 Subject: [PATCH 01/41] Allow low precision sd xl --- src/diffusers/models/autoencoder_kl.py | 1 + .../pipeline_stable_diffusion_xl.py | 42 ++++++++++-------- .../pipeline_stable_diffusion_xl_img2img.py | 43 +++++++++++-------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index ddb9bde0ee0a..56288bfc0ca9 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -82,6 +82,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + upcast_precision: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b3dcf1b67cda..a7eee64724f9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -537,6 +537,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -799,25 +819,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + self.upcast_vae() + latents = latents.to(self.vae.post_quant_conv.dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 7b0cdfad8c0a..815cd5521b64 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -624,6 +624,27 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids + # Copied from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionPipelineXL.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -932,25 +953,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + self.upcast_vae() + latents = latents.to(self.vae.post_quant_conv.dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] From 43f842cb75b4ce134873522c3de5dc581564d2f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 20:30:17 +0000 Subject: [PATCH 02/41] finish --- src/diffusers/models/autoencoder_kl.py | 2 +- .../pipeline_stable_diffusion_xl.py | 5 ++--- .../pipeline_stable_diffusion_xl_img2img.py | 15 ++++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 56288bfc0ca9..33326dd80b1b 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -82,7 +82,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, - upcast_precision: float = True, + force_upcast: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a7eee64724f9..2c513a9cfb32 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -556,7 +556,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -819,9 +818,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(self.vae.post_quant_conv.dtype) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": image = 
self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 815cd5521b64..5c6cfdf18dd6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -542,8 +542,9 @@ def prepare_latents( else: # make sure the VAE is in float32 mode, as it overflows in float16 - image = image.float() - self.vae.to(dtype=torch.float32) + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -559,9 +560,10 @@ def prepare_latents( else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - self.vae.to(dtype) - init_latents = init_latents.to(dtype) + if self.vae.config.force_upcast: + self.vae.to(dtype) + init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: @@ -644,7 +646,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -953,9 +954,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(self.vae.post_quant_conv.dtype) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] From 7171d423b191129be8da95014568fabd3694d54c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 21:24:54 +0000 Subject: [PATCH 03/41] finish --- .../pipeline_stable_diffusion_upscale.py | 42 ++++++++++--------- .../pipeline_stable_diffusion_xl.py | 1 + .../pipeline_stable_diffusion_xl_img2img.py | 2 +- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index a7255424fb46..59aa87f46c19 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -501,6 +501,25 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() def __call__( self, @@ -746,26 +765,9 @@ def __call__( # 10. 
Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # post-processing if not output_type == "latent": diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2c513a9cfb32..bcdf47b7c983 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -537,6 +537,7 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 5c6cfdf18dd6..f14e7fa457e3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -626,7 +626,7 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionPipelineXL.upcast_vae + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) From 933753577c245ab6e8171fc5601bbb3cc70c758f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 16:17:26 +0530 Subject: [PATCH 04/41] feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth --- .../dreambooth/train_dreambooth_lora_sdxl.py | 318 ++++++++++++------ src/diffusers/loaders.py | 42 ++- 2 files changed, 246 insertions(+), 114 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 81d87e9c49e2..884448c76b9b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -16,6 +16,7 @@ import argparse import gc import hashlib +import itertools import logging import math import os @@ -49,7 +50,7 @@ DPMSolverMultistepScheduler, UNet2DConditionModel, ) -from diffusers.loaders import LoraLoaderMixin +from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0 from diffusers.optimization import get_scheduler from 
diffusers.utils import check_min_version, is_wandb_available @@ -63,12 +64,7 @@ def save_model_card( - repo_id: str, - images=None, - base_model=str, - train_text_encoder=False, - prompt=str, - repo_folder=None, + repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None ): img_str = "" for i, image in enumerate(images): @@ -96,6 +92,8 @@ def save_model_card( {img_str} LoRA for the text encoder was enabled: {train_text_encoder}. + +Special VAE used for training: {vae_path}. """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -130,6 +128,12 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) parser.add_argument( "--revision", type=str, @@ -420,9 +424,6 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") - if args.train_text_encoder and args.pre_compute_text_embeddings: - raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") - return args @@ -445,6 +446,7 @@ def __init__( class_prompt_hidden_states=None, instance_unet_added_conditions=None, class_unet_added_conditions=None, + tokenizers=None, ): self.size = size self.center_crop = center_crop @@ -453,6 +455,10 @@ def __init__( self.instance_unet_added_conditions = instance_unet_added_conditions self.class_unet_added_conditions = class_unet_added_conditions + if tokenizers is not None: + self.tokenizer_one = tokenizers[0] + self.tokenizer_two = tokenizers[1] + self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") @@ -496,7 +502,11 @@ def __getitem__(self, index): instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.instance_prompt_hidden_states + if self.instance_prompt_hidden_states is not None: + example["instance_prompt_ids"] = self.instance_prompt_hidden_states + else: + example["instance_prompt_tokens_one"] = tokenize_prompt(self.tokenizer_one, self.instance_prompt) + example["instance_prompt_tokens_two"] = tokenize_prompt(self.tokenizer_two, self.instance_prompt) example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions if self.class_data_root: @@ -506,45 +516,69 @@ def __getitem__(self, index): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.class_prompt_hidden_states + if self.class_prompt_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_hidden_states + else: + example["class_prompt_tokens_one"] = tokenize_prompt(self.tokenizer_one, self.class_prompt) + example["class_prompt_tokens_two"] = tokenize_prompt(self.tokenizer_two, self.class_prompt) example["class_added_cond_kwargs"] = self.class_unet_added_conditions return example -def collate_fn(examples, with_prior_preservation=False): +def collate_fn(examples, train_text_encoder=False, with_prior_preservation=False): has_attention_mask = "instance_attention_mask" in examples[0] - 
- input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples] add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples] + + if not train_text_encoder: + input_ids = [example["instance_prompt_ids"] for example in examples] + add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples] + else: + tokens_one = [example["instance_prompt_tokens_one"] for example in examples] + tokens_two = [example["instance_prompt_tokens_two"] for example in examples] + if has_attention_mask: attention_mask = [example["instance_attention_mask"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] - add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples] add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples] + if not train_text_encoder: + input_ids += [example["class_prompt_ids"] for example in examples] + add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples] + else: + tokens_one += [example["class_prompt_tokens_one"] for example in examples] + tokens_two += [example["class_prompt_tokens_two"] for example in examples] + if has_attention_mask: attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = torch.cat(input_ids, dim=0) - add_text_embeds = torch.cat(add_text_embeds, dim=0) add_time_ids = torch.cat(add_time_ids, dim=0) - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, - } + if not train_text_encoder: + input_ids = torch.cat(input_ids, dim=0) + add_text_embeds = torch.cat(add_text_embeds, dim=0) + else: + tokens_one = torch.cat(tokens_one, dim=0) + tokens_two = torch.cat(tokens_two, dim=0) + + unet_added_conditions = {"time_ids": add_time_ids} + if not train_text_encoder: + unet_added_conditions.update({"text_embeds": add_text_embeds}) + batch = {"input_ids": input_ids, "pixel_values": pixel_values, "unet_added_conditions": unet_added_conditions} + else: + batch = { + "tokens_one": tokens_one, + "tokens_two": tokens_two, + "pixel_values": pixel_values, + "unet_added_conditions": unet_added_conditions, + } if has_attention_mask: batch["attention_mask"] = attention_mask @@ -569,27 +603,29 @@ def __getitem__(self, index): return example +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(text_encoders, tokenizers, prompt): +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - prompt, - padding="max_length", - 
max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), @@ -641,9 +677,6 @@ def main(args): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") import wandb - if args.train_text_encoder: - raise NotImplementedError("Text encoder training not yet supported.") - # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -742,7 +775,14 @@ def main(args): text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) @@ -764,7 +804,10 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. unet.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=torch.float32) + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) @@ -804,16 +847,31 @@ def main(args): unet_lora_parameters.extend(module.parameters()) unet.set_attn_processor(unet_lora_attn_procs) - # unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32) + text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): # there are only two options here. 
Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None for model in models: - unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -821,25 +879,33 @@ def save_model_hook(models, weights, output_dir): LoraLoaderMixin.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=None, + text_encoder_lora_layers=[text_encoder_one_lora_layers_to_save, text_encoder_two_lora_layers_to_save], ) def load_model_hook(models, input_dir): unet_ = None - text_encoder_ = None + text_encoder_one_ = None + text_encoder_two_ = None while len(models) > 0: model = models.pop() if isinstance(model, type(accelerator.unwrap_model(unet))): unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_ + ) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook) @@ -869,7 +935,11 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = unet_lora_parameters + params_to_optimize = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -878,48 +948,45 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) - # We ALWAYS pre-compute the additional condition embeddings needed for SDXL - # UNet as the model is already big and it uses two text encoders. - # TODO: when we add support for text encoder training, will reivist. - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] - - # Here, we compute not just the text embeddings but also the additional embeddings - # needed for the SD XL UNet to operate. - def compute_embeddings(prompt, text_encoders, tokenizers): + # Computes additional embeddings required by the SDXL UNet. 
+ def compute_additional_embeddings(): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids original_size = (args.resolution, args.resolution) target_size = (args.resolution, args.resolution) crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) - add_text_embeds = pooled_prompt_embeds + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + unet_added_cond_kwargs = {"time_ids": add_time_ids} - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) + return unet_added_cond_kwargs - prompt_embeds = prompt_embeds.to(accelerator.device) - add_text_embeds = add_text_embeds.to(accelerator.device) - add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) - unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if not args.train_text_encoder: - return prompt_embeds, unet_added_cond_kwargs + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds - instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) + instance_unet_added_conditions = compute_additional_embeddings() + instance_prompt_hidden_states, instance_pooled_prompt_embeds = None, None + if not args.train_text_encoder: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings() + instance_unet_added_conditions.update({"text_embeds": instance_pooled_prompt_embeds}) - class_prompt_hidden_states, class_unet_added_conditions = None, None + class_prompt_hidden_states, class_unet_added_conditions, class_pooled_prompt_embeds = None, None, None if args.with_prior_preservation: - class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + class_unet_added_conditions = compute_additional_embeddings() + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings() + class_unet_added_conditions.update({"text_embeds": class_pooled_prompt_embeds}) - del tokenizers, text_encoders - - gc.collect() - torch.cuda.empty_cache() + del tokenizers, text_encoders + gc.collect() + torch.cuda.empty_cache() # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( @@ -934,13 +1001,14 @@ def compute_embeddings(prompt, text_encoders, tokenizers): class_prompt_hidden_states=class_prompt_hidden_states, instance_unet_added_conditions=instance_unet_added_conditions, class_unet_added_conditions=class_unet_added_conditions, + tokenizers=[tokenizer_one, tokenizer_two] if args.train_text_encoder else None, ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + collate_fn=lambda examples: collate_fn(examples, 
args.train_text_encoder, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) @@ -954,16 +1022,21 @@ def compute_embeddings(prompt, text_encoders, tokenizers): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1030,12 +1103,16 @@ def compute_embeddings(prompt, text_encoders, tokenizers): continue with accelerator.accumulate(unet): - # pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is None: + pixel_values = batch["pixel_values"] + else: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) # Convert images to latent space - model_input = vae.encode(batch["pixel_values"]).latent_dist.sample() + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor - model_input = model_input.to(weight_dtype) + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1051,9 +1128,24 @@ def compute_embeddings(prompt, text_encoders, tokenizers): noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Predict the noise residual - model_pred = unet( - noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"] - ).sample + if not args.train_text_encoder: + model_pred = unet( + noisy_model_input, + timesteps, + batch["input_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + ).sample + else: + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[batch["tokens_one"], batch["tokens_two"]], + ) + batch["unet_added_conditions"].update({"text_embeds": pooled_prompt_embeds}) + model_pred = unet( + noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=batch["unet_added_conditions"] + ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -1081,7 +1173,11 @@ def compute_embeddings(prompt, text_encoders, tokenizers): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = unet_lora_parameters + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) 
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -1132,8 +1228,20 @@ def compute_embeddings(prompt, text_encoders, tokenizers): f" {args.validation_prompt}." ) # create pipeline + if not args.train_text_encoder: + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap(text_encoder_one) if args.train_text_encoder else text_encoder_one, + text_encoder_2=accelerator.unwrap(text_encoder_two) + if args.train_text_encoder + else text_encoder_two, unet=accelerator.unwrap_model(unet), revision=args.revision, torch_dtype=weight_dtype, @@ -1192,13 +1300,20 @@ def compute_embeddings(prompt, text_encoders, tokenizers): LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=None, + text_encoder_lora_layers=[text_lora_parameters_one, text_lora_parameters_two] + if args.train_text_encoder + else None, ) # Final inference # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + ) pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it @@ -1250,6 +1365,7 @@ def compute_embeddings(prompt, text_encoders, tokenizers): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, + vae_path=args.vae_path, ) upload_folder( repo_id=repo_id, diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 561ae740738c..642143bf15dd 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -828,6 +828,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di self.load_lora_into_text_encoder( state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale ) + is_sdxl = any("text_encoder_2" in key for key in state_dict) + if is_sdxl: + self.load_lora_into_text_encoder( + state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder_2, lora_scale=self.lora_scale + ) @classmethod def lora_state_dict( @@ -1146,7 +1151,9 @@ def save_lora_weights( self, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + text_encoder_lora_layers: Union[ + Dict[str, Union[torch.nn.Module, torch.Tensor]], List[Dict[str, Union[torch.nn.Module, torch.Tensor]]] + ] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1160,9 +1167,9 @@ def save_lora_weights( Directory to save LoRA parameters to. Will be created if it doesn't exist. unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. 
- text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): + text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`) or `List`: State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes 🤗 Transformers. + encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1187,7 +1194,7 @@ def save_function(weights, filename): os.makedirs(save_directory, exist_ok=True) - # Create a flat dictionary. + # Create a flat dictionary and populate. state_dict = {} if unet_lora_layers is not None: weights = ( @@ -1198,16 +1205,25 @@ def save_function(weights, filename): state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: - weights = ( - text_encoder_lora_layers.state_dict() - if isinstance(text_encoder_lora_layers, torch.nn.Module) - else text_encoder_lora_layers - ) - text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() - } - state_dict.update(text_encoder_lora_state_dict) + def get_te_params(te_state_dict): + weights = te_state_dict.state_dict() if isinstance(te_state_dict, torch.nn.Module) else te_state_dict + return weights + + if not isinstance(text_encoder_lora_layers, list): + weights = get_te_params(text_encoder_lora_layers) + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + } + state_dict.update(text_encoder_lora_state_dict) + else: + for i in range(len(text_encoder_lora_layers)): + weights = get_te_params(text_encoder_lora_layers[i]) + weight_suffix = f"{self.text_encoder_name}_{i+1}" if i > 0 else self.text_encoder_name + text_encoder_lora_state_dict = { + f"{weight_suffix}.{module_name}": param for module_name, param in weights.items() + } + state_dict.update(text_encoder_lora_state_dict) # Save the model if weight_name is None: From b487dfc9aee80b86009ca4a29d6c150d062faa9f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 17:09:06 +0530 Subject: [PATCH 05/41] fix: variable assignments. 
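The previous commit declared `compute_text_embeddings(prompt, text_encoders, tokenizers)` but invoked it with no arguments, and it deleted `tokenizers`/`text_encoders` unconditionally even though those lists are never built when the text encoders are being trained. Bind the lists on the frozen-encoder path, thread the arguments through, and guard the cleanup.

A condensed sketch of the corrected instance-prompt path (all names come from this training script; the real code also reuses the encoders for the class prompt before freeing them):

```python
if not args.train_text_encoder:
    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]

    # Embeddings can be cached up front because the encoders stay frozen.
    instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
        args.instance_prompt, text_encoders, tokenizers
    )
    instance_unet_added_conditions.update({"text_embeds": instance_pooled_prompt_embeds})

    # Drop the cached handles; when the encoders are trained they must stay
    # live for the training loop, so this cleanup is now conditional.
    del tokenizers, text_encoders
    gc.collect()
    torch.cuda.empty_cache()
```
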
--- .../dreambooth/train_dreambooth_lora_sdxl.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 884448c76b9b..6f52ab71229f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -963,6 +963,8 @@ def compute_additional_embeddings(): return unet_added_cond_kwargs if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): @@ -974,19 +976,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): instance_unet_added_conditions = compute_additional_embeddings() instance_prompt_hidden_states, instance_pooled_prompt_embeds = None, None if not args.train_text_encoder: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings() + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) instance_unet_added_conditions.update({"text_embeds": instance_pooled_prompt_embeds}) class_prompt_hidden_states, class_unet_added_conditions, class_pooled_prompt_embeds = None, None, None if args.with_prior_preservation: class_unet_added_conditions = compute_additional_embeddings() if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings() + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) class_unet_added_conditions.update({"text_embeds": class_pooled_prompt_embeds}) - del tokenizers, text_encoders - gc.collect() - torch.cuda.empty_cache() + if not args.train_text_encoder: + del tokenizers, text_encoders + gc.collect() + torch.cuda.empty_cache() # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( From c51e5598461fdd7aa28a21c3843bf26a5ed8afb7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 17:33:52 +0530 Subject: [PATCH 06/41] add: autocast block. 
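During mixed-precision training the trainable LoRA parameters are deliberately kept in float32 (see the `_modify_text_encoder(..., dtype=torch.float32)` calls) while the frozen base weights run in float16, so validation sampling can fail on dtype mismatches. Run the validation pipeline under autocast so the fp32 adapter weights compose with the fp16 base weights.

A sketch of the changed validation loop, with `pipeline`, `args`, and `accelerator` as defined in the script:

```python
import torch

generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

# autocast casts eligible ops to half precision on the fly, so mixing
# fp32 LoRA layers with the fp16 UNet/text encoders does not error out.
with torch.cuda.amp.autocast():
    images = [
        pipeline(**pipeline_args, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]
```
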
--- examples/dreambooth/train_dreambooth_lora_sdxl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6f52ab71229f..0e0414817788 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1102,6 +1102,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): for epoch in range(first_epoch, args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -1276,9 +1279,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - images = [ - pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) - ] + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": From 9d23e30c2ce69ca2f3d851f9852dff37bca153d0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 17:39:42 +0530 Subject: [PATCH 07/41] add debugging --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0e0414817788..492a22d52334 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1119,6 +1119,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pixel_values = batch["pixel_values"].to(dtype=weight_dtype) # Convert images to latent space + print(f"vae: {vae.dtype} pixel_values: {pixel_values.dtype}") model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: From ea285db0705ccd28d70fd82f8c7946afecc1ab30 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 17:49:28 +0530 Subject: [PATCH 08/41] vae dtype hell --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 492a22d52334..261859033205 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1115,11 +1115,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): with accelerator.accumulate(unet): if args.pretrained_vae_model_name_or_path is None: pixel_values = batch["pixel_values"] + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) # Convert images to latent space - print(f"vae: {vae.dtype} pixel_values: {pixel_values.dtype}") model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: From b0ed1b666a81fce7e7ec1a15d860084918c00ef4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 
Jul 2023 17:53:52 +0530 Subject: [PATCH 09/41] fix: vae dtype hell. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 261859033205..293c6afe4db6 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1115,12 +1115,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): with accelerator.accumulate(unet): if args.pretrained_vae_model_name_or_path is None: pixel_values = batch["pixel_values"] - if vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) # Convert images to latent space + if args.pretrained_vae_model_name_or_path and vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: From 20a9186b9f7d1e3d161c80fa6c5b4b598aff786c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 17:58:24 +0530 Subject: [PATCH 10/41] fix: vae dtype hell 3. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 293c6afe4db6..16def33e0804 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1117,10 +1117,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pixel_values = batch["pixel_values"] else: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) # Convert images to latent space - if args.pretrained_vae_model_name_or_path and vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) + # if args.pretrained_vae_model_name_or_path is not None and vae.dtype != weight_dtype: + # vae.to(dtype=weight_dtype) model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: From 9d3e606e551e3210186698d1db173391c2d62d0e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 18:03:18 +0530 Subject: [PATCH 11/41] clean up --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 16def33e0804..a3d54642de7f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1121,8 +1121,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): vae.to(dtype=weight_dtype) # Convert images to latent space - # if args.pretrained_vae_model_name_or_path is not None and vae.dtype != weight_dtype: - # vae.to(dtype=weight_dtype) model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: From 35b8daec2347fbd926c2ca9d7bf6c3d39e151720 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 18:17:56 +0530 Subject: [PATCH 12/41] lora text encoder loader. 
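SDXL's second text encoder is a `CLIPTextModelWithProjection`, not a plain `CLIPTextModel`, so `text_encoder_attn_modules` did not recognize it and LoRA layers could not be patched into its attention blocks. Widen the isinstance check to cover both classes.

A hedged end-to-end sketch of what this enables (the output directory is hypothetical; the base model ID is the public SDXL release):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Loads the UNet LoRA layers and, with this fix, the LoRA layers for both
# text encoders (state-dict keys prefixed `text_encoder.` / `text_encoder_2.`).
pipe.load_lora_weights("path/to/output_dir")  # hypothetical training output dir
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
```
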
--- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 642143bf15dd..c4dbf04b7881 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -50,7 +50,7 @@ import safetensors if is_transformers_available(): - from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer + from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer logger = logging.get_logger(__name__) @@ -96,7 +96,7 @@ def forward(self, input): def text_encoder_attn_modules(text_encoder): attn_modules = [] - if isinstance(text_encoder, CLIPTextModel): + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): for i, layer in enumerate(text_encoder.text_model.encoder.layers): name = f"text_model.encoder.layers.{i}.self_attn" mod = layer.self_attn From 9c305b1118744194d760a9cd943e2c29056543a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 18:29:58 +0530 Subject: [PATCH 13/41] fix: unwrapping models. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index a3d54642de7f..25179233cdba 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -894,7 +894,7 @@ def load_model_hook(models, input_dir): unet_ = model elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1250,8 +1250,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap(text_encoder_one) if args.train_text_encoder else text_encoder_one, - text_encoder_2=accelerator.unwrap(text_encoder_two) + text_encoder=accelerator.unwrap_model(text_encoder_one) + if args.train_text_encoder + else text_encoder_one, + text_encoder_2=accelerator.unwrap_model(text_encoder_two) if args.train_text_encoder else text_encoder_two, unet=accelerator.unwrap_model(unet), @@ -1311,12 +1313,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): unet = unet.to(torch.float32) unet_lora_layers = unet_attn_processors_state_dict(unet) + if args.train_text_encoder: + text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_one = text_encoder_one.to(torch.float32) + text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_two = text_encoder_two.to(torch.float32) + text_encoder_lora_layers = [ + text_encoder_lora_state_dict(text_encoder_one), + text_encoder_lora_state_dict(text_encoder_two), + ] + else: + text_encoder_lora_layers = None + LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=[text_lora_parameters_one, text_lora_parameters_two] - if args.train_text_encoder - else None, + text_encoder_lora_layers=text_encoder_lora_layers, ) # Final inference From 4afb79391e3eeee4b27b72bcbfc4aecf6358b07b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 18:54:51 +0530 Subject: [PATCH 
14/41] add: tests. --- examples/test_examples.py | 36 ++++++++++++++++++++++++++++++++++++ src/diffusers/loaders.py | 11 ++++++++--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index cc3b3fbf7478..646e4ec74276 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -385,6 +385,42 @@ def test_dreambooth_lora_sdxl(self): starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) self.assertTrue(starts_with_unet) + def test_dreambooth_lora_sdxl_with_text_encoder(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --train_text_encoder + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names. + keys = lora_state_dict.keys() + starts_with_unet = all( + k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys + ) + self.assertTrue(starts_with_unet) + def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c4dbf04b7881..9319228217a8 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1032,12 +1032,17 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo keys = list(state_dict.keys()) if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)] + if isinstance(text_encoder, CLIPTextModel): + prefix = cls.text_encoder_name + text_encoder_keys = [k for k in keys if k.startswith(prefix)] + elif isinstance(text_encoder, CLIPTextModelWithProjection): + prefix = f"{cls.text_encoder_name}_2" + text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { - k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {cls.text_encoder_name}.") + logger.info(f"Loading {prefix}.") if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): # Convert from the old naming convention to the new naming convention. From 42fb43362add441e1c339f021dce8853fc661eba Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 19:03:17 +0530 Subject: [PATCH 15/41] docs. 
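Document the two new options in the SDXL DreamBooth LoRA README: `--train_text_encoder`, which LoRA-finetunes both text encoders (otherwise the text embeddings are always precomputed to save memory), and `--pretrained_vae_model_name_or_path`, which swaps in a more numerically stable VAE.

A sketch of the better-VAE workflow the README points to, using the fp16-fix VAE it links; inference-side usage is an assumption here, the README itself only covers training:

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# The stock SDXL VAE tends to overflow in float16; this community VAE was
# finetuned to be fp16-safe, so it can skip the float32 upcast at decode time.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
)
```
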
---
 examples/dreambooth/README_sdxl.md | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md
index 51133c24bc1a..429df55a5d1d 100644
--- a/examples/dreambooth/README_sdxl.md
+++ b/examples/dreambooth/README_sdxl.md
@@ -164,6 +164,16 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output
 |---|---|
 | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
 
+### Training with text encoder(s)
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
+
+* SDXL has two text encoders. So, we fine-tune both using LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
 ## Notes
 
 In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗

From 0f9588715f6dac33c930692b164a03571f3848c4 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 19:06:39 +0530
Subject: [PATCH 16/41] handle unexpected keys.

---
 src/diffusers/loaders.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 9319228217a8..db7846ea7a19 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1095,9 +1095,12 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
         load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
         if len(load_state_dict_results.unexpected_keys) != 0:
-            raise ValueError(
-                f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
-            )
+            unexpected_keys = load_state_dict_results.unexpected_keys
+            is_sdxl_ckpt = any("text_encoder_" in k for k in unexpected_keys)
+            if not is_sdxl_ckpt:
+                raise ValueError(
+                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+                )
 
     @property
     def lora_scale(self) -> float:

From ede8ca2769347db13cff309ef00ce862179ea26e Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 19:24:32 +0530
Subject: [PATCH 17/41] fix vae dtype in the final inference.
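For context, the better-VAE flow the README above describes looks roughly like this at
inference time. A minimal sketch, assuming the fp16-fix VAE linked in the README and the
public SDXL base checkpoint id:

    import torch
    from diffusers import AutoencoderKL, StableDiffusionXLPipeline

    # Swap in the numerically stabler VAE, then hand it to the pipeline.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
    ).to("cuda")

    image = pipe("a photo of a sks dog").images[0]
    image.save("sks_dog.png")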
--- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 25179233cdba..c6bfc66d808f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1338,6 +1338,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, ) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype ) From 34b536c891f6424167d778ad627c50764ada52ff Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 19:31:01 +0530 Subject: [PATCH 18/41] fix scope problem. --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index db7846ea7a19..6cac7c0a4f7d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1034,10 +1034,10 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo # Load the layers corresponding to text encoder and make necessary adjustments. if isinstance(text_encoder, CLIPTextModel): prefix = cls.text_encoder_name - text_encoder_keys = [k for k in keys if k.startswith(prefix)] elif isinstance(text_encoder, CLIPTextModelWithProjection): prefix = f"{cls.text_encoder_name}_2" - text_encoder_keys = [k for k in keys if k.startswith(prefix)] + + text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } From ad2617443c897585b57e845725f721cd88193784 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 19:32:46 +0530 Subject: [PATCH 19/41] fix: save_model_card args. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c6bfc66d808f..d836d0ebabab 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1393,7 +1393,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, - vae_path=args.vae_path, + vae_path=args.pretrained_vae_model_name_or_path, ) upload_folder( repo_id=repo_id, From c5a95d66db65fc8082df20ea6d1a90e532ad9c81 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 19:54:16 +0530 Subject: [PATCH 20/41] initialize: prefix to None. --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6cac7c0a4f7d..06eb3234e2bf 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1030,6 +1030,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. 
keys = list(state_dict.keys())
+        prefix = None
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):

From 63f62b4194066465dcd345954d5a292953e17a5a Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 20:03:41 +0530
Subject: [PATCH 21/41] fix: dtype issues.

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index d836d0ebabab..f7106736d58b 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1337,9 +1337,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
+        torch_dtype=weight_dtype
     )
-    if args.pretrained_vae_model_name_or_path is not None:
-        vae.to(dtype=weight_dtype)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
     )

From 2d815d2a2cbec534da71d7015465aab4ce5ba585 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 20:07:36 +0530
Subject: [PATCH 22/41] apply fixes.

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +-
 src/diffusers/loaders.py                          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index f7106736d58b..3da4a63d2ffc 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1337,7 +1337,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
-        torch_dtype=weight_dtype
+        torch_dtype=weight_dtype,
     )
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 06eb3234e2bf..9d6c5d8a6720 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1031,7 +1031,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
         # their prefixes.
         keys = list(state_dict.keys())
         prefix = None
-        
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
             if isinstance(text_encoder, CLIPTextModel):

From 3bb3d4fa667f363c7428b43f7e7309982879ab47 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 21:33:02 +0530
Subject: [PATCH 23/41] debugging.

---
 src/diffusers/loaders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 2280cc494d26..d6332f3cb105 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1032,6 +1032,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
keys = list(state_dict.keys()) + print(f"From loaders: {keys}") prefix = None if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): From 6b496d5e067c30566b3f4e3767a019863754ed92 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 21:41:07 +0530 Subject: [PATCH 24/41] debugging --- src/diffusers/loaders.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d6332f3cb105..410db7a94943 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1041,7 +1041,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo prefix = cls.text_encoder_name elif isinstance(text_encoder, CLIPTextModelWithProjection): prefix = f"{cls.text_encoder_name}_2" - + for k in keys: + if "text_encoder" in k: + print(f"From loader: {k}") text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys From d4a68a929fff73abdf9e68b8297f0c21aed859cf Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 21:42:50 +0530 Subject: [PATCH 25/41] debugging --- src/diffusers/loaders.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 410db7a94943..0e63fd253f9a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1037,13 +1037,14 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. + for k in keys: + if "text_encoder" in k: + print(f"From loader: {k}") if isinstance(text_encoder, CLIPTextModel): prefix = cls.text_encoder_name elif isinstance(text_encoder, CLIPTextModelWithProjection): prefix = f"{cls.text_encoder_name}_2" - for k in keys: - if "text_encoder" in k: - print(f"From loader: {k}") + text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys From b70b0036d0c33cfa55460b9234d0c8a2a463fa17 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 21:45:03 +0530 Subject: [PATCH 26/41] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0e63fd253f9a..31acb599d13c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1034,7 +1034,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo keys = list(state_dict.keys()) print(f"From loaders: {keys}") prefix = None - + print(f"From loaders: {all(key.startswith(cls.unet_name) for key in keys)}") if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
for k in keys: From fe9ca1456184c9cc1157e63edb9f34d2cd56963f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 21:49:34 +0530 Subject: [PATCH 27/41] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 31acb599d13c..d1e8dd5bb7ab 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1035,6 +1035,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo print(f"From loaders: {keys}") prefix = None print(f"From loaders: {all(key.startswith(cls.unet_name) for key in keys)}") + print(f"From loaders: {all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys)}") if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. for k in keys: From 3d968bbbf7ccda77632b95878d68346327c34c36 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 21:57:47 +0530 Subject: [PATCH 28/41] debugging --- src/diffusers/loaders.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d1e8dd5bb7ab..a5dc539856fd 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1032,20 +1032,15 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - print(f"From loaders: {keys}") prefix = None - print(f"From loaders: {all(key.startswith(cls.unet_name) for key in keys)}") - print(f"From loaders: {all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys)}") - if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + + if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. - for k in keys: - if "text_encoder" in k: - print(f"From loader: {k}") if isinstance(text_encoder, CLIPTextModel): prefix = cls.text_encoder_name elif isinstance(text_encoder, CLIPTextModelWithProjection): prefix = f"{cls.text_encoder_name}_2" - + text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys From bcac03226c5f27d4ed0a611751ee5855e576cea9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 14:32:14 +0530 Subject: [PATCH 29/41] add: fast tests. 
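A side note on the detection logic settled in the previous patch: the prefix split itself is
easy to reason about in isolation. A standalone sketch with made-up keys (illustrative only,
not an actual checkpoint layout):

    # Illustrative SDXL LoRA keys.
    state_dict = {
        "unet.down_blocks.0.attentions.0.processor.to_q_lora.down.weight": 0,
        "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.weight": 0,
        "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.weight": 0,
    }

    def split_by_prefix(sd, prefix):
        # Matching on f"{prefix}." keeps "text_encoder" from swallowing
        # "text_encoder_2" keys: the character after "text_encoder" there is "_".
        return {k[len(prefix) + 1 :]: v for k, v in sd.items() if k.startswith(f"{prefix}.")}

    te_one = split_by_prefix(state_dict, "text_encoder")    # 1 key
    te_two = split_by_prefix(state_dict, "text_encoder_2")  # 1 key
    assert len(te_one) == len(te_two) == 1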
--- src/diffusers/loaders.py | 2 + tests/models/test_lora_layers.py | 188 ++++++++++++++++++++++++++++++- 2 files changed, 186 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a5dc539856fd..891b0def5808 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1114,6 +1114,8 @@ def lora_scale(self) -> float: def _remove_text_encoder_monkey_patch(self): self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + if hasattr(self, "text_encoder_2"): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) @classmethod def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 1396561367e0..0b5693f8f972 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -21,9 +21,16 @@ import torch.nn as nn import torch.nn.functional as F from huggingface_hub.repocard import RepoCard -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer - -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerDiscreteScheduler, + StableDiffusionPipeline, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules from diffusers.models.attention_processor import ( Attention, @@ -150,7 +157,7 @@ def get_dummy_components(self): } lora_components = { "unet_lora_layers": unet_lora_layers, - "text_encoder_lora_layers": text_encoder_lora_layers, + "text_encoder_one_lora_layers": text_encoder_lora_layers, "unet_lora_attn_procs": unet_lora_attn_procs, } return pipeline_components, lora_components @@ -495,6 +502,179 @@ def test_lora_save_load_with_xformers(self): self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) +class SDXLLoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = 
CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder) + text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "text_encoder_one_lora_layers": text_encoder_one_lora_layers, + "text_encoder_two_lora_layers": text_encoder_two_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=[ + lora_components["text_encoder_one_lora_layers"], + lora_components["text_encoder_two_lora_layers"], + ], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_unload_lora(self): + pipeline_components, lora_components = self.get_dummy_components() + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + # Emulate training. 
+ set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=[ + lora_components["text_encoder_one_lora_layers"], + lora_components["text_encoder_two_lora_layers"], + ], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Unload LoRA parameters. + sd_pipe.unload_lora_weights() + original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice_two = original_images_two[0, -3:, -3:, -1] + + assert not np.allclose( + orig_image_slice, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert not np.allclose( + orig_image_slice_two, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert np.allclose( + orig_image_slice, orig_image_slice_two, atol=1e-3 + ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters." + + @slow @require_torch_gpu class LoraIntegrationTests(unittest.TestCase): From ff3f27f55f9bd0e3e8a741022967a863fa2101a7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 14:44:56 +0530 Subject: [PATCH 30/41] pre-tokenize. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 3da4a63d2ffc..1c9817e3e284 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -490,6 +490,14 @@ def __init__( ] ) + # Pre-tokenize so that we don't have to do it each time dataloader is iterated. 
+ if tokenizers is not None: + self.instance_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.instance_prompt) + self.instance_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.instance_prompt) + if self.class_prompt is not None: + self.class_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.class_prompt) + self.class_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.class_prompt) + def __len__(self): return self._length @@ -505,8 +513,8 @@ def __getitem__(self, index): if self.instance_prompt_hidden_states is not None: example["instance_prompt_ids"] = self.instance_prompt_hidden_states else: - example["instance_prompt_tokens_one"] = tokenize_prompt(self.tokenizer_one, self.instance_prompt) - example["instance_prompt_tokens_two"] = tokenize_prompt(self.tokenizer_two, self.instance_prompt) + example["instance_prompt_tokens_one"] = self.instance_prompt_tokens_one + example["instance_prompt_tokens_two"] = self.instance_prompt_tokens_two example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions if self.class_data_root: @@ -519,8 +527,8 @@ def __getitem__(self, index): if self.class_prompt_hidden_states is not None: example["class_prompt_ids"] = self.class_prompt_hidden_states else: - example["class_prompt_tokens_one"] = tokenize_prompt(self.tokenizer_one, self.class_prompt) - example["class_prompt_tokens_two"] = tokenize_prompt(self.tokenizer_two, self.class_prompt) + example["class_prompt_tokens_one"] = self.class_prompt_tokens_one + example["class_prompt_tokens_two"] = self.class_prompt_tokens_two example["class_added_cond_kwargs"] = self.class_unet_added_conditions return example From 0c335661931d311fcef1d7ae68caedd21af53f4e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 15:33:31 +0530 Subject: [PATCH 31/41] address: will's comments. --- src/diffusers/loaders.py | 39 ++++++++----------- .../pipeline_stable_diffusion_xl.py | 19 +++++++++ 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 891b0def5808..d24e89125b8a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -830,11 +830,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di self.load_lora_into_text_encoder( state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale ) - is_sdxl = any("text_encoder_2" in key for key in state_dict) - if is_sdxl: - self.load_lora_into_text_encoder( - state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder_2, lora_scale=self.lora_scale - ) @classmethod def lora_state_dict( @@ -1011,7 +1006,7 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet): warnings.warn(warn_message) @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1032,15 +1027,10 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - prefix = None + prefix = cls.text_encoder_name if prefix is None else prefix if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
- if isinstance(text_encoder, CLIPTextModel): - prefix = cls.text_encoder_name - elif isinstance(text_encoder, CLIPTextModelWithProjection): - prefix = f"{cls.text_encoder_name}_2" - text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys @@ -1099,12 +1089,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) if len(load_state_dict_results.unexpected_keys) != 0: - unexpected_keys = load_state_dict_results.unexpected_keys - is_sdxl_ckpt = any("text_encoder_" in k for k in unexpected_keys) - if not is_sdxl_ckpt: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) @property def lora_scale(self) -> float: @@ -1170,11 +1157,12 @@ def save_lora_weights( ] = None, is_main_process: bool = True, weight_name: str = None, + text_encoder_weight_prefix: str = None, save_function: Callable = None, safe_serialization: bool = False, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the UNet and text encoder(s). Arguments: save_directory (`str` or `os.PathLike`): @@ -1188,6 +1176,10 @@ def save_lora_weights( Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. + weight_name (`str`, *optional*, defaults to `None`): + Exact name of the weight file to use during serialization (`my_trained_lora_weights.bin`, for example). + text_encoder_weight_prefix (`str`, *optional*, defaults to `None`): + Prefix to use for serializing the LoRA parameters corresponding to the text encoder. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. 
Can be configured with the environment variable @@ -1219,6 +1211,9 @@ def save_function(weights, filename): state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: + text_encoder_weight_prefix = ( + self.text_encoder_name if text_encoder_weight_prefix is None else text_encoder_weight_prefix + ) def get_te_params(te_state_dict): weights = te_state_dict.state_dict() if isinstance(te_state_dict, torch.nn.Module) else te_state_dict @@ -1227,15 +1222,15 @@ def get_te_params(te_state_dict): if not isinstance(text_encoder_lora_layers, list): weights = get_te_params(text_encoder_lora_layers) text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + f"{text_encoder_weight_prefix}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) else: for i in range(len(text_encoder_lora_layers)): weights = get_te_params(text_encoder_lora_layers[i]) - weight_suffix = f"{self.text_encoder_name}_{i+1}" if i > 0 else self.text_encoder_name + modified_prefix = f"{text_encoder_weight_prefix}_{i+1}" if i > 0 else text_encoder_weight_prefix text_encoder_lora_state_dict = { - f"{weight_suffix}.{module_name}": param for module_name, param in weights.items() + f"{modified_prefix}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index bcdf47b7c983..ebb288792e9e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -840,3 +840,22 @@ def __call__( return (image,) return StableDiffusionXLPipelineOutput(images=image) + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + self.load_lora_into_text_encoder( + state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + self.load_lora_weights( + state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) From 43d70464c593c4535fa1922ff3f38203e2828856 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 16:37:06 +0530 Subject: [PATCH 32/41] fix: loader and tests. 
--- src/diffusers/loaders.py | 9 ++++++--- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- tests/models/test_lora_layers.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d24e89125b8a..e5817a1f9db5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1089,9 +1089,12 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) if len(load_state_dict_results.unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) + unexpected_keys = load_state_dict_results.unexpected_keys + is_sdxl_ckpt = any("text_encoder_" in k for k in unexpected_keys) + if not is_sdxl_ckpt: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) @property def lora_scale(self) -> float: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index ebb288792e9e..d714e41ae6e7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -852,7 +852,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di prefix="text_encoder", lora_scale=self.lora_scale, ) - self.load_lora_weights( + self.load_lora_into_text_encoder( state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder_2, diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 0b5693f8f972..ba3efb30a0a3 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -157,7 +157,7 @@ def get_dummy_components(self): } lora_components = { "unet_lora_layers": unet_lora_layers, - "text_encoder_one_lora_layers": text_encoder_lora_layers, + "text_encoder_lora_layers": text_encoder_lora_layers, "unet_lora_attn_procs": unet_lora_attn_procs, } return pipeline_components, lora_components From 8d5d5b15f37dc4ab77823c7dea6cfd70f22d7870 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:36:15 +0530 Subject: [PATCH 33/41] fix: dataloader. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 1c9817e3e284..13de7645a42f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -494,7 +494,7 @@ def __init__( if tokenizers is not None: self.instance_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.instance_prompt) self.instance_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.instance_prompt) - if self.class_prompt is not None: + if class_data_root is not None: self.class_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.class_prompt) self.class_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.class_prompt) From 0d77b53999b744362fafd108938f9dae2f0dcca5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 11:45:07 +0530 Subject: [PATCH 34/41] simplify dataloader. 
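The idea behind the simplification that follows: prompts are constant for a whole training
run, so everything derived from them can be computed once up front, leaving the dataset to
serve images only. A toy, self-contained illustration of the pattern (not the script's actual
classes; a dummy tensor stands in for real tokenizer output):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ImagesOnlyDataset(Dataset):
        # Stand-in for the simplified DreamBoothDataset: it no longer knows
        # anything about prompts or tokenizers.
        def __init__(self, num_images=4):
            self.images = [torch.randn(3, 64, 64) for _ in range(num_images)]

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            return {"instance_images": self.images[index]}

    # Prompt-derived tensors are computed once, outside the dataset.
    static_tokens = torch.randint(0, 1000, (1, 77))

    loader = DataLoader(ImagesOnlyDataset(), batch_size=2)
    for batch in loader:
        pixel_values = batch["instance_images"]
        # static_tokens is simply reused every step instead of being re-tokenized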
--- .../dreambooth/train_dreambooth_lora_sdxl.py | 142 +++++------------- 1 file changed, 38 insertions(+), 104 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 13de7645a42f..2a239f30c9f2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -430,34 +430,19 @@ def parse_args(input_args=None): class DreamBoothDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. + It pre-processes the images. """ def __init__( self, instance_data_root, - instance_prompt, class_data_root=None, - class_prompt=None, class_num=None, size=1024, center_crop=False, - instance_prompt_hidden_states=None, - class_prompt_hidden_states=None, - instance_unet_added_conditions=None, - class_unet_added_conditions=None, - tokenizers=None, ): self.size = size self.center_crop = center_crop - self.instance_prompt_hidden_states = instance_prompt_hidden_states - self.class_prompt_hidden_states = class_prompt_hidden_states - self.instance_unet_added_conditions = instance_unet_added_conditions - self.class_unet_added_conditions = class_unet_added_conditions - - if tokenizers is not None: - self.tokenizer_one = tokenizers[0] - self.tokenizer_two = tokenizers[1] self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -465,8 +450,6 @@ def __init__( self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) - self.instance_prompt = instance_prompt - self._length = self.num_instance_images if class_data_root is not None: self.class_data_root = Path(class_data_root) @@ -477,7 +460,6 @@ def __init__( else: self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) - self.class_prompt = class_prompt else: self.class_data_root = None @@ -490,14 +472,6 @@ def __init__( ] ) - # Pre-tokenize so that we don't have to do it each time dataloader is iterated. 
- if tokenizers is not None: - self.instance_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.instance_prompt) - self.instance_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.instance_prompt) - if class_data_root is not None: - self.class_prompt_tokens_one = tokenize_prompt(self.tokenizer_one, self.class_prompt) - self.class_prompt_tokens_two = tokenize_prompt(self.tokenizer_two, self.class_prompt) - def __len__(self): return self._length @@ -510,13 +484,6 @@ def __getitem__(self, index): instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - if self.instance_prompt_hidden_states is not None: - example["instance_prompt_ids"] = self.instance_prompt_hidden_states - else: - example["instance_prompt_tokens_one"] = self.instance_prompt_tokens_one - example["instance_prompt_tokens_two"] = self.instance_prompt_tokens_two - example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions - if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) @@ -524,73 +491,22 @@ def __getitem__(self, index): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - if self.class_prompt_hidden_states is not None: - example["class_prompt_ids"] = self.class_prompt_hidden_states - else: - example["class_prompt_tokens_one"] = self.class_prompt_tokens_one - example["class_prompt_tokens_two"] = self.class_prompt_tokens_two - example["class_added_cond_kwargs"] = self.class_unet_added_conditions return example -def collate_fn(examples, train_text_encoder=False, with_prior_preservation=False): - has_attention_mask = "instance_attention_mask" in examples[0] +def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] - add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples] - - if not train_text_encoder: - input_ids = [example["instance_prompt_ids"] for example in examples] - add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples] - else: - tokens_one = [example["instance_prompt_tokens_one"] for example in examples] - tokens_two = [example["instance_prompt_tokens_two"] for example in examples] - - if has_attention_mask: - attention_mask = [example["instance_attention_mask"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. 
if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] - add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples] - - if not train_text_encoder: - input_ids += [example["class_prompt_ids"] for example in examples] - add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples] - else: - tokens_one += [example["class_prompt_tokens_one"] for example in examples] - tokens_two += [example["class_prompt_tokens_two"] for example in examples] - - if has_attention_mask: - attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - add_time_ids = torch.cat(add_time_ids, dim=0) - - if not train_text_encoder: - input_ids = torch.cat(input_ids, dim=0) - add_text_embeds = torch.cat(add_text_embeds, dim=0) - else: - tokens_one = torch.cat(tokens_one, dim=0) - tokens_two = torch.cat(tokens_two, dim=0) - - unet_added_conditions = {"time_ids": add_time_ids} - if not train_text_encoder: - unet_added_conditions.update({"text_embeds": add_text_embeds}) - batch = {"input_ids": input_ids, "pixel_values": pixel_values, "unet_added_conditions": unet_added_conditions} - else: - batch = { - "tokens_one": tokens_one, - "tokens_two": tokens_two, - "pixel_values": pixel_values, - "unet_added_conditions": unet_added_conditions, - } - - if has_attention_mask: - batch["attention_mask"] = attention_mask + batch = {"pixel_values": pixel_values} return batch @@ -970,6 +886,7 @@ def compute_additional_embeddings(): return unet_added_cond_kwargs + # Pack the embeddings appropriately. if not args.train_text_encoder: tokenizers = [tokenizer_one, tokenizer_two] text_encoders = [text_encoder_one, text_encoder_two] @@ -982,14 +899,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): return prompt_embeds, pooled_prompt_embeds instance_unet_added_conditions = compute_additional_embeddings() - instance_prompt_hidden_states, instance_pooled_prompt_embeds = None, None if not args.train_text_encoder: instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) instance_unet_added_conditions.update({"text_embeds": instance_pooled_prompt_embeds}) - class_prompt_hidden_states, class_unet_added_conditions, class_pooled_prompt_embeds = None, None, None if args.with_prior_preservation: class_unet_added_conditions = compute_additional_embeddings() if not args.train_text_encoder: @@ -998,32 +913,47 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) class_unet_added_conditions.update({"text_embeds": class_pooled_prompt_embeds}) + # Clear the memory here. if not args.train_text_encoder: del tokenizers, text_encoders gc.collect() torch.cuda.empty_cache() + # Pack the statically computed variables appropriately. This is so that we don't + # have to pass them to the dataloader. 
+ add_time_ids = instance_unet_added_conditions["time_ids"] + if args.with_prior_preservation: + add_time_ids = torch.cat([add_time_ids, class_unet_added_conditions], dim=0) + + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, class_num=args.num_class_images, size=args.resolution, center_crop=args.center_crop, - instance_prompt_hidden_states=instance_prompt_hidden_states, - class_prompt_hidden_states=class_prompt_hidden_states, - instance_unet_added_conditions=instance_unet_added_conditions, - class_unet_added_conditions=class_unet_added_conditions, - tokenizers=[tokenizer_one, tokenizer_two] if args.train_text_encoder else None, ) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.train_text_encoder, args.with_prior_preservation), + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) @@ -1125,8 +1055,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pixel_values = batch["pixel_values"] else: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - if vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1149,22 +1077,28 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Predict the noise residual if not args.train_text_encoder: + unet_added_conditions = { + "time_ids": add_time_ids.repeat(bsz, 1), + "text_embeds": unet_add_text_embeds.repeat(bsz, 1), + } model_pred = unet( noisy_model_input, timesteps, - batch["input_ids"], - added_cond_kwargs=batch["unet_added_conditions"], + prompt_embeds.repeat(bsz, 1, 1), + added_cond_kwargs=unet_added_conditions, ).sample else: + unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, prompt=None, - text_input_ids_list=[batch["tokens_one"], batch["tokens_two"]], + text_input_ids_list=[tokens_one, tokens_two], ) - batch["unet_added_conditions"].update({"text_embeds": pooled_prompt_embeds}) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)}) + prompt_embeds = prompt_embeds.repeat(bsz, 1, 1) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=batch["unet_added_conditions"] + noisy_model_input, timesteps, prompt_embeds, 
added_cond_kwargs=unet_added_conditions
+                ).sample
 
                 # Get the target for loss depending on the prediction type

From 91b0c3ac0d54c14475d3ef6e24cb5aaf4f6f907f Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 21 Jul 2023 12:17:18 +0530
Subject: [PATCH 35/41] length.

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 2a239f30c9f2..b5fbaa3805e5 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -450,6 +450,7 @@ def __init__(
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
         self.num_instance_images = len(self.instance_images_path)
+        self._length = self.num_instance_images
 
         if class_data_root is not None:
             self.class_data_root = Path(class_data_root)

From 52eef75d9ef9d8d5167fec34c3c2df0647fb5973 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 21 Jul 2023 13:27:59 +0530
Subject: [PATCH 36/41] simplification.

---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index b5fbaa3805e5..4fe7ef854394 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -873,21 +873,21 @@ def load_model_hook(models, input_dir):
         eps=args.adam_epsilon,
     )
 
-    # Computes additional embeddings required by the SDXL UNet.
-    def compute_additional_embeddings():
+    # Computes additional embeddings/ids required by the SDXL UNet.
+    # regular text embeddings (when `train_text_encoder` is not True)
+    # pooled text embeddings
+    # time ids
+
+    def compute_time_ids():
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
         original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
         crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
-        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
-        unet_added_cond_kwargs = {"time_ids": add_time_ids}
-
-        return unet_added_cond_kwargs
+        return add_time_ids
 
-    # Pack the embeddings appropriately.
     if not args.train_text_encoder:
         tokenizers = [tokenizer_one, tokenizer_two]
         text_encoders = [text_encoder_one, text_encoder_two]
@@ -899,20 +899,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
 
-    instance_unet_added_conditions = compute_additional_embeddings()
+    # Handle instance prompt.
+    instance_time_ids = compute_time_ids()
     if not args.train_text_encoder:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
             args.instance_prompt, text_encoders, tokenizers
         )
-        instance_unet_added_conditions.update({"text_embeds": instance_pooled_prompt_embeds})
 
+    # Handle class prompt for prior-preservation.
    if args.with_prior_preservation:
-        class_unet_added_conditions = compute_additional_embeddings()
+        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
             )
-            class_unet_added_conditions.update({"text_embeds": class_pooled_prompt_embeds})
 
     # Clear the memory here.
     if not args.train_text_encoder:
@@ -922,9 +922,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Pack the statically computed variables appropriately. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_unet_added_conditions["time_ids"]
+    add_time_ids = instance_time_ids
     if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_unet_added_conditions], dim=0)
+        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
 
     if not args.train_text_encoder:
         prompt_embeds = instance_prompt_hidden_states

From ef501c882bf48197ee313a8402bdd54c2281b31c Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 21 Jul 2023 13:33:38 +0530
Subject: [PATCH 37/41] make style && make quality

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 4fe7ef854394..4d3ebd229f7a 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -877,7 +877,7 @@ def load_model_hook(models, input_dir):
     # regular text embeddings (when `train_text_encoder` is not True)
     # pooled text embeddings
     # time ids
-    
+
     def compute_time_ids():
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
         original_size = (args.resolution, args.resolution)

From 07a45c88ea4fc58964089feadcb857f9b4ddaa3b Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 21 Jul 2023 14:37:58 +0530
Subject: [PATCH 38/41] simplify state_dict munging

---
 .../dreambooth/train_dreambooth_lora_sdxl.py |  25 ++--
 src/diffusers/loaders.py                     | 109 +++++++++---------
 .../pipeline_stable_diffusion_xl.py          |  64 +++++++---
 3 files changed, 117 insertions(+), 81 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 4d3ebd229f7a..a774ccb7d6a2 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -46,8 +46,8 @@
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
-    DiffusionPipeline,
     DPMSolverMultistepScheduler,
+    StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
@@ -635,7 +635,7 @@ def main(args):
             torch_dtype = torch.float16
         elif args.prior_generation_precision == "bf16":
             torch_dtype = torch.bfloat16
-        pipeline = DiffusionPipeline.from_pretrained(
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
             safety_checker=None,
@@ -801,10 +801,11 @@ def save_model_hook(models, weights, output_dir):
                 # make sure to pop weight so that corresponding model is not saved again
                 weights.pop()
 
-            LoraLoaderMixin.save_lora_weights(
+            StableDiffusionXLPipeline.save_lora_weights(
                 output_dir,
                 unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=[text_encoder_one_lora_layers_to_save, text_encoder_two_lora_layers_to_save],
+                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) def load_model_hook(models, input_dir): @@ -1190,7 +1191,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=accelerator.unwrap_model(text_encoder_one) @@ -1258,20 +1259,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_one = text_encoder_one.to(torch.float32) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_two = text_encoder_two.to(torch.float32) - text_encoder_lora_layers = [ - text_encoder_lora_state_dict(text_encoder_one), - text_encoder_lora_state_dict(text_encoder_two), - ] + text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) else: text_encoder_lora_layers = None + text_encoder_2_lora_layers = None - LoraLoaderMixin.save_lora_weights( + StableDiffusionXLPipeline.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) # Final inference @@ -1282,7 +1281,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): revision=args.revision, torch_dtype=weight_dtype, ) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype ) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 85722d7351be..9629de3b619d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1021,12 +1021,14 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr Parameters: state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key shoult be prefixed with an + A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. network_alpha (`float`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. 
@@ -1098,12 +1100,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
         load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
         if len(load_state_dict_results.unexpected_keys) != 0:
-            unexpected_keys = load_state_dict_results.unexpected_keys
-            is_sdxl_ckpt = any("text_encoder_" in k for k in unexpected_keys)
-            if not is_sdxl_ckpt:
-                raise ValueError(
-                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
-                )
+            raise ValueError(
+                f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+            )

     @property
     def lora_scale(self) -> float:
@@ -1113,8 +1112,6 @@ def lora_scale(self) -> float:

     def _remove_text_encoder_monkey_patch(self):
         self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
-        if hasattr(self, "text_encoder_2"):
-            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)

     @classmethod
     def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
@@ -1164,56 +1161,36 @@ def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
-        text_encoder_lora_layers: Union[
-            Dict[str, Union[torch.nn.Module, torch.Tensor]], List[Dict[str, Union[torch.nn.Module, torch.Tensor]]]
-        ] = None,
+        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
-        text_encoder_weight_prefix: str = None,
         save_function: Callable = None,
         safe_serialization: bool = False,
     ):
         r"""
-        Save the LoRA parameters corresponding to the UNet and text encoder(s).
+        Save the LoRA parameters corresponding to the UNet and text encoder.

         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to save LoRA parameters to. Will be created if it doesn't exist.
             unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`) or `List`:
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and
                 you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                 main process to avoid race conditions.
-            weight_name (`str`, *optional*, defaults to `None`):
-                Exact name of the weight file to use during serialization (`my_trained_lora_weights.bin`, for example).
-            text_encoder_weight_prefix (`str`, *optional*, defaults to `None`):
-                Prefix to use for serializing the LoRA parameters corresponding to the text encoder.
             save_function (`Callable`):
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
""" - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - - else: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - - # Create a flat dictionary and populate. + # Create a flat dictionary. state_dict = {} + + # Populate the dictionary. if unet_lora_layers is not None: weights = ( unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers @@ -1223,30 +1200,52 @@ def save_function(weights, filename): state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: - text_encoder_weight_prefix = ( - self.text_encoder_name if text_encoder_weight_prefix is None else text_encoder_weight_prefix + weights = ( + text_encoder_lora_layers.state_dict() + if isinstance(text_encoder_lora_layers, torch.nn.Module) + else text_encoder_lora_layers ) - def get_te_params(te_state_dict): - weights = te_state_dict.state_dict() if isinstance(te_state_dict, torch.nn.Module) else te_state_dict - return weights + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + } + state_dict.update(text_encoder_lora_state_dict) + + # Save the model + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + @classmethod + def write_lora_layers( + self, + state_dict: Dict[str, torch.Tensor], + save_directory: str, + is_main_process: bool, + weight_name: str, + save_function: Callable, + safe_serialization: bool, + ): + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - if not isinstance(text_encoder_lora_layers, list): - weights = get_te_params(text_encoder_lora_layers) - text_encoder_lora_state_dict = { - f"{text_encoder_weight_prefix}.{module_name}": param for module_name, param in weights.items() - } - state_dict.update(text_encoder_lora_state_dict) else: - for i in range(len(text_encoder_lora_layers)): - weights = get_te_params(text_encoder_lora_layers[i]) - modified_prefix = f"{text_encoder_weight_prefix}_{i+1}" if i > 0 else text_encoder_weight_prefix - text_encoder_lora_state_dict = { - f"{modified_prefix}.{module_name}": param for module_name, param in weights.items() - } - state_dict.update(text_encoder_lora_state_dict) + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) - # Save the model if weight_name is None: if safe_serialization: weight_name = LORA_WEIGHT_NAME_SAFE diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 36d499752836..c608c2c7d300 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -814,17 +815,54 @@ def __call__( def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) - self.load_lora_into_text_encoder( - state_dict, - network_alpha=network_alpha, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alpha=network_alpha, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, + + text_encoder_state_dict = {k: v for k, v in state_dict if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, torch.Tensor], + text_encoder_lora_layers: Dict[str, torch.Tensor], + text_encoder_2_lora_layers: Dict[str, torch.Tensor], + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + ): + state_dict = {} + state_dict.update({f"unet.{k}": v for k, v in unet_lora_layers.state_dict()}) + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder_lora_layers.state_dict()}) + state_dict.update({f"text_encoder_2.{k}": v for k, v in text_encoder_2_lora_layers.state_dict()}) + + self.write_lora_weights( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) From 6ca45f379bdc4b8910d63b37737a6874098666f0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 14:41:44 +0530 Subject: [PATCH 39/41] fix: tests. 
--- tests/models/test_lora_layers.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ba3efb30a0a3..531ca0472634 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -406,7 +406,7 @@ def test_lora_unet_attn_processors(self): ) self.assertIsInstance(module.processor, attn_proc_class) - def test_unload_lora(self): + def test_unload_lora_sd(self): pipeline_components, lora_components = self.get_dummy_components() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) sd_pipe = StableDiffusionPipeline(**pipeline_components) @@ -614,13 +614,11 @@ def test_lora_save_load(self): orig_image_slice = original_images[0, -3:, -3:, -1] with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( + StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=[ - lora_components["text_encoder_one_lora_layers"], - lora_components["text_encoder_two_lora_layers"], - ], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) sd_pipe.load_lora_weights(tmpdirname) @@ -631,7 +629,7 @@ def test_lora_save_load(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - def test_unload_lora(self): + def test_unload_lora_sdxl(self): pipeline_components, lora_components = self.get_dummy_components() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) sd_pipe = StableDiffusionXLPipeline(**pipeline_components) @@ -645,13 +643,11 @@ def test_unload_lora(self): set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( + StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=[ - lora_components["text_encoder_one_lora_layers"], - lora_components["text_encoder_two_lora_layers"], - ], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) sd_pipe.load_lora_weights(tmpdirname) From 5f4e0899ec91c72197360583e1309a6def7489d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 15:21:41 +0530 Subject: [PATCH 40/41] fix: state_dict packing. 
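
The earlier comprehensions iterated over the LoRA layer dicts directly
(e.g. `{f"unet.{k}": v for k, v in unet_lora_layers.state_dict()}`), which
yields only the key names and fails to unpack into (key, value) pairs.
Packing now goes through a `pack_weights` helper that accepts either an
`nn.Module` or a plain state dict and iterates over `.items()`. Below is a
standalone, illustrative sketch of the helper's behavior; the module and
prefix are arbitrary examples, not part of this patch:

    import torch

    def pack_weights(layers, prefix):
        # Accept either an nn.Module or an already-materialized state dict.
        weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        # `.items()` yields (name, tensor) pairs; iterating the dict itself
        # would only yield the key names.
        return {f"{prefix}.{name}": param for name, param in weights.items()}

    linear = torch.nn.Linear(4, 4)
    packed = pack_weights(linear, "unet")
    assert sorted(packed) == ["unet.bias", "unet.weight"]
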
--- src/diffusers/loaders.py | 2 -- .../pipeline_stable_diffusion_xl.py | 25 ++++++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9629de3b619d..4e6fea7c5971 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1221,9 +1221,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - @classmethod def write_lora_layers( - self, state_dict: Dict[str, torch.Tensor], save_directory: str, is_main_process: bool, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c608c2c7d300..0c338a6540dc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -816,7 +816,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) - text_encoder_state_dict = {k: v for k, v in state_dict if "text_encoder." in k} + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_state_dict, @@ -826,7 +826,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) - text_encoder_2_state_dict = {k: v for k, v in state_dict if "text_encoder_2." in k} + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} if len(text_encoder_2_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_2_state_dict, @@ -840,21 +840,28 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di def save_lora_weights( self, save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, torch.Tensor], - text_encoder_lora_layers: Dict[str, torch.Tensor], - text_encoder_2_lora_layers: Dict[str, torch.Tensor], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = False, ): state_dict = {} - state_dict.update({f"unet.{k}": v for k, v in unet_lora_layers.state_dict()}) + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder_lora_layers.state_dict()}) - state_dict.update({f"text_encoder_2.{k}": v for k, v in text_encoder_2_lora_layers.state_dict()}) + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - self.write_lora_weights( + self.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, From 
989e54d0a705c991392d7bde21393e74e780c31d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 20:18:23 +0530 Subject: [PATCH 41/41] Apply suggestions from code review Co-authored-by: Patrick von Platen --- examples/dreambooth/README_sdxl.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md index 429df55a5d1d..b490de8e2f5f 100644 --- a/examples/dreambooth/README_sdxl.md +++ b/examples/dreambooth/README_sdxl.md @@ -174,6 +174,7 @@ Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To ### Specifying a better VAE SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + ## Notes In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
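
For reference, an end-to-end sketch of the LoRA serialization API this series
settles on; the checkpoint id and file paths below are placeholders, not part
of the patches:

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # `load_lora_weights` routes `text_encoder.*` keys to the first text
    # encoder and `text_encoder_2.*` keys to the second one; `unet.*` keys
    # go to the UNet as before.
    pipe.load_lora_weights("path/to/trained-sdxl-lora")  # placeholder path

    image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
    image.save("sks_dog.png")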