diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md index 51133c24bc1a..b490de8e2f5f 100644 --- a/examples/dreambooth/README_sdxl.md +++ b/examples/dreambooth/README_sdxl.md @@ -164,6 +164,17 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output |---|---| | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) | +### Training with text encoder(s) + +Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: + +* SDXL has two text encoders. So, we fine-tune both using LoRA. +* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. + +### Specifying a better VAE + +SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + ## Notes In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗 diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 81d87e9c49e2..a774ccb7d6a2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -16,6 +16,7 @@ import argparse import gc import hashlib +import itertools import logging import math import os @@ -45,11 +46,11 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, - DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionXLPipeline, UNet2DConditionModel, ) -from diffusers.loaders import LoraLoaderMixin +from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0 from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available @@ -63,12 +64,7 @@ def save_model_card( - repo_id: str, - images=None, - base_model=str, - train_text_encoder=False, - prompt=str, - repo_folder=None, + repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None ): img_str = "" for i, image in enumerate(images): @@ -96,6 +92,8 @@ def save_model_card( {img_str} LoRA for the text encoder was enabled: {train_text_encoder}. + +Special VAE used for training: {vae_path}. """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -130,6 +128,12 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) parser.add_argument( "--revision", type=str, @@ -420,38 +424,25 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") - if args.train_text_encoder and args.pre_compute_text_embeddings: - raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") - return args class DreamBoothDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. + It pre-processes the images. """ def __init__( self, instance_data_root, - instance_prompt, class_data_root=None, - class_prompt=None, class_num=None, size=1024, center_crop=False, - instance_prompt_hidden_states=None, - class_prompt_hidden_states=None, - instance_unet_added_conditions=None, - class_unet_added_conditions=None, ): self.size = size self.center_crop = center_crop - self.instance_prompt_hidden_states = instance_prompt_hidden_states - self.class_prompt_hidden_states = class_prompt_hidden_states - self.instance_unet_added_conditions = instance_unet_added_conditions - self.class_unet_added_conditions = class_unet_added_conditions self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -459,7 +450,6 @@ def __init__( self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) - self.instance_prompt = instance_prompt self._length = self.num_instance_images if class_data_root is not None: @@ -471,7 +461,6 @@ def __init__( else: self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) - self.class_prompt = class_prompt else: self.class_data_root = None @@ -496,9 +485,6 @@ def __getitem__(self, index): instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.instance_prompt_hidden_states - example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions - if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) @@ -506,49 +492,22 @@ def __getitem__(self, index): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.class_prompt_hidden_states - example["class_added_cond_kwargs"] = self.class_unet_added_conditions return example def collate_fn(examples, with_prior_preservation=False): - has_attention_mask = "instance_attention_mask" in examples[0] - - input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples] - add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples] - if has_attention_mask: - attention_mask = [example["instance_attention_mask"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] - add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples] - add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples] - - if has_attention_mask: - attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.cat(input_ids, dim=0) - add_text_embeds = torch.cat(add_text_embeds, dim=0) - add_time_ids = torch.cat(add_time_ids, dim=0) - - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, - } - - if has_attention_mask: - batch["attention_mask"] = attention_mask - + batch = {"pixel_values": pixel_values} return batch @@ -569,27 +528,29 @@ def __getitem__(self, index): return example +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(text_encoders, tokenizers, prompt): +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), @@ -641,9 +602,6 @@ def main(args): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") import wandb - if args.train_text_encoder: - raise NotImplementedError("Text encoder training not yet supported.") - # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -677,7 +635,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - pipeline = DiffusionPipeline.from_pretrained( + pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None, @@ -742,7 +700,14 @@ def main(args): text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) @@ -764,7 +729,10 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. unet.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=torch.float32) + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) @@ -804,42 +772,66 @@ def main(args): unet_lora_parameters.extend(module.parameters()) unet.set_attn_processor(unet_lora_attn_procs) - # unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32) + text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): # there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None for model in models: - unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again weights.pop() - LoraLoaderMixin.save_lora_weights( + StableDiffusionXLPipeline.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=None, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) def load_model_hook(models, input_dir): unet_ = None - text_encoder_ = None + text_encoder_one_ = None + text_encoder_two_ = None while len(models) > 0: model = models.pop() if isinstance(model, type(accelerator.unwrap_model(unet))): unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_ + ) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook) @@ -869,7 +861,11 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # Optimizer creation - params_to_optimize = unet_lora_parameters + params_to_optimize = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -878,62 +874,81 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) - # We ALWAYS pre-compute the additional condition embeddings needed for SDXL - # UNet as the model is already big and it uses two text encoders. - # TODO: when we add support for text encoder training, will reivist. - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] + # Computes additional embeddings/ids required by the SDXL UNet. + # regular text emebddings (when `train_text_encoder` is not True) + # pooled text embeddings + # time ids - # Here, we compute not just the text embeddings but also the additional embeddings - # needed for the SD XL UNet to operate. - def compute_embeddings(prompt, text_encoders, tokenizers): + def compute_time_ids(): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids original_size = (args.resolution, args.resolution) target_size = (args.resolution, args.resolution) crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds + + # Handle instance prompt. + instance_time_ids = compute_time_ids() + if not args.train_text_encoder: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) - add_text_embeds = pooled_prompt_embeds - - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - - prompt_embeds = prompt_embeds.to(accelerator.device) - add_text_embeds = add_text_embeds.to(accelerator.device) - add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) - unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - return prompt_embeds, unet_added_cond_kwargs - - instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) - - class_prompt_hidden_states, class_unet_added_conditions = None, None + # Handle class prompt for prior-preservation. if args.with_prior_preservation: - class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + class_time_ids = compute_time_ids() + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) - del tokenizers, text_encoders + # Clear the memory here. + if not args.train_text_encoder: + del tokenizers, text_encoders + gc.collect() + torch.cuda.empty_cache() - gc.collect() - torch.cuda.empty_cache() + # Pack the statically computed variables appropriately. This is so that we don't + # have to pass them to the dataloader. + add_time_ids = instance_time_ids + if args.with_prior_preservation: + add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) + + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, class_num=args.num_class_images, size=args.resolution, center_crop=args.center_crop, - instance_prompt_hidden_states=instance_prompt_hidden_states, - class_prompt_hidden_states=class_prompt_hidden_states, - instance_unet_added_conditions=instance_unet_added_conditions, - class_unet_added_conditions=class_unet_added_conditions, ) train_dataloader = torch.utils.data.DataLoader( @@ -954,16 +969,21 @@ def compute_embeddings(prompt, text_encoders, tokenizers): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1022,6 +1042,9 @@ def compute_embeddings(prompt, text_encoders, tokenizers): for epoch in range(first_epoch, args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: @@ -1030,12 +1053,16 @@ def compute_embeddings(prompt, text_encoders, tokenizers): continue with accelerator.accumulate(unet): - # pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is None: + pixel_values = batch["pixel_values"] + else: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) # Convert images to latent space - model_input = vae.encode(batch["pixel_values"]).latent_dist.sample() + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor - model_input = model_input.to(weight_dtype) + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1051,9 +1078,30 @@ def compute_embeddings(prompt, text_encoders, tokenizers): noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Predict the noise residual - model_pred = unet( - noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"] - ).sample + if not args.train_text_encoder: + unet_added_conditions = { + "time_ids": add_time_ids.repeat(bsz, 1), + "text_embeds": unet_add_text_embeds.repeat(bsz, 1), + } + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds.repeat(bsz, 1, 1), + added_cond_kwargs=unet_added_conditions, + ).sample + else: + unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)} + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)}) + prompt_embeds = prompt_embeds.repeat(bsz, 1, 1) + model_pred = unet( + noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions + ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -1081,7 +1129,11 @@ def compute_embeddings(prompt, text_encoders, tokenizers): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = unet_lora_parameters + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -1132,8 +1184,22 @@ def compute_embeddings(prompt, text_encoders, tokenizers): f" {args.validation_prompt}." ) # create pipeline - pipeline = DiffusionPipeline.from_pretrained( + if not args.train_text_encoder: + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one) + if args.train_text_encoder + else text_encoder_one, + text_encoder_2=accelerator.unwrap_model(text_encoder_two) + if args.train_text_encoder + else text_encoder_two, unet=accelerator.unwrap_model(unet), revision=args.revision, torch_dtype=weight_dtype, @@ -1161,9 +1227,11 @@ def compute_embeddings(prompt, text_encoders, tokenizers): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - images = [ - pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) - ] + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -1189,16 +1257,32 @@ def compute_embeddings(prompt, text_encoders, tokenizers): unet = unet.to(torch.float32) unet_lora_layers = unet_attn_processors_state_dict(unet) - LoraLoaderMixin.save_lora_weights( + if args.train_text_encoder: + text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + else: + text_encoder_lora_layers = None + text_encoder_2_lora_layers = None + + StableDiffusionXLPipeline.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=None, + text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) # Final inference # Load previous pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it @@ -1250,6 +1334,7 @@ def compute_embeddings(prompt, text_encoders, tokenizers): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, ) upload_folder( repo_id=repo_id, diff --git a/examples/test_examples.py b/examples/test_examples.py index cc3b3fbf7478..646e4ec74276 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -385,6 +385,42 @@ def test_dreambooth_lora_sdxl(self): starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) self.assertTrue(starts_with_unet) + def test_dreambooth_lora_sdxl_with_text_encoder(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --train_text_encoder + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names. + keys = lora_state_dict.keys() + starts_with_unet = all( + k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys + ) + self.assertTrue(starts_with_unet) + def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8ce5989b5f49..4e6fea7c5971 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -58,7 +58,7 @@ import safetensors if is_transformers_available(): - from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer + from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer if is_accelerate_available(): from accelerate import init_empty_weights @@ -107,7 +107,7 @@ def forward(self, input): def text_encoder_attn_modules(text_encoder): attn_modules = [] - if isinstance(text_encoder, CLIPTextModel): + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): for i, layer in enumerate(text_encoder.text_model.encoder.layers): name = f"text_model.encoder.layers.{i}.self_attn" mod = layer.self_attn @@ -1015,18 +1015,20 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet): warnings.warn(warn_message) @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` Parameters: state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key shoult be prefixed with an + A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. network_alpha (`float`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. @@ -1036,14 +1038,16 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + prefix = cls.text_encoder_name if prefix is None else prefix + + if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)] + text_encoder_keys = [k for k in keys if k.startswith(prefix)] text_encoder_lora_state_dict = { - k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {cls.text_encoder_name}.") + logger.info(f"Loading {prefix}.") if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): # Convert from the old naming convention to the new naming convention. @@ -1183,23 +1187,10 @@ def save_lora_weights( replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) - - else: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - # Create a flat dictionary. state_dict = {} + + # Populate the dictionary. if unet_lora_layers is not None: weights = ( unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers @@ -1221,6 +1212,38 @@ def save_function(weights, filename): state_dict.update(text_encoder_lora_state_dict) # Save the model + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def write_lora_layers( + state_dict: Dict[str, torch.Tensor], + save_directory: str, + is_main_process: bool, + weight_name: str, + save_function: Callable, + safe_serialization: bool, + ): + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + if weight_name is None: if safe_serialization: weight_name = LORA_WEIGHT_NAME_SAFE diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 9863c663910f..0c338a6540dc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -809,3 +810,66 @@ def __call__( return (image,) return StableDiffusionXLPipelineOutput(images=image) + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 1396561367e0..531ca0472634 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -21,9 +21,16 @@ import torch.nn as nn import torch.nn.functional as F from huggingface_hub.repocard import RepoCard -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer - -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerDiscreteScheduler, + StableDiffusionPipeline, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules from diffusers.models.attention_processor import ( Attention, @@ -399,7 +406,7 @@ def test_lora_unet_attn_processors(self): ) self.assertIsInstance(module.processor, attn_proc_class) - def test_unload_lora(self): + def test_unload_lora_sd(self): pipeline_components, lora_components = self.get_dummy_components() _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) sd_pipe = StableDiffusionPipeline(**pipeline_components) @@ -495,6 +502,175 @@ def test_lora_save_load_with_xformers(self): self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) +class SDXLLoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder) + text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "text_encoder_one_lora_layers": text_encoder_one_lora_layers, + "text_encoder_two_lora_layers": text_encoder_two_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_unload_lora_sdxl(self): + pipeline_components, lora_components = self.get_dummy_components() + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Unload LoRA parameters. + sd_pipe.unload_lora_weights() + original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice_two = original_images_two[0, -3:, -3:, -1] + + assert not np.allclose( + orig_image_slice, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert not np.allclose( + orig_image_slice_two, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert np.allclose( + orig_image_slice, orig_image_slice_two, atol=1e-3 + ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters." + + @slow @require_torch_gpu class LoraIntegrationTests(unittest.TestCase):