From ca7f220b8aaff1b222bb9625ad482420e73e0b7d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 09:12:15 +0530 Subject: [PATCH 01/95] =?UTF-8?q?add:=20script=20to=20train=20lcm=20lora?= =?UTF-8?q?=20for=20sdxl=20with=20=F0=9F=A4=97=20datasets?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../train_lcm_distill_lora_sdxl.py | 1332 +++++++++++++++++ .../text_to_image/train_text_to_image_flax.py | 5 +- .../train_text_to_image_lora_sdxl.py | 5 +- 3 files changed, 1334 insertions(+), 8 deletions(-) create mode 100644 examples/consistency_distillation/train_lcm_distill_lora_sdxl.py diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py new file mode 100644 index 000000000000..c3b314e2ab36 --- /dev/null +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -0,0 +1,1332 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + LCMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.24.0.dev0") + +logger = get_logger(__name__) + +MAX_SEQ_LENGTH = 77 +EMBEDDING_DIM = 2048 +POOLED_PROJECTION_DIM = 1280 + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +class DDIMSolver: + def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): + # DDIM sampling parameters + step_ratio = timesteps // ddim_timesteps + + self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + + def to(self, device): + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + return self + + def ddim_step(self, pred_x0, pred_noise, timestep_index): + alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + +def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): + kohya_ss_state_dict = {} + for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) + + return kohya_ss_state_dict + + +def log_validation(vae, unet, args, accelerator, weight_dtype, step): + logger.info("Running validation... 
") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) + pipeline.load_lora_weights(lora_state_dict) + pipeline.fuse_lora() + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "cute sundar pichai character", + "robotic cat with wings", + "a photo of yoda", + "a cute creature with blue eyes", + ] + + image_logs = [] + + for _, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda", dtype=weight_dtype): + images = pipeline( + prompt=prompt, + num_inference_steps=4, + num_images_per_prompt=4, + generator=generator, + guidance_scale=0.0, + ).images + image_logs.append({"validation_prompt": prompt, "images": images}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# From LCMScheduler.get_scalings_for_boundary_condition_discrete +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) + c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +# Compare LCMScheduler.step, Step 4 +def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "v_prediction": + pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + else: + raise ValueError(f"Prediction type {prediction_type} currently not supported.") + + return pred_x_0 + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def 
tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--use_fix_crop_and_size", + action="store_true", + help="Whether or not to use the fixed crop and size for the teacher model.", + default=False, + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
+ ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Latent Consistency Distillation (LCD) Specific Arguments---- + parser.add_argument( + "--w_min", + type=float, + default=3.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=15.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--num_ddim_timesteps", + type=int, + default=50, + help="The number of timesteps to use for DDIM sampling.", + ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber"], + help="The type of loss to use for the LCD loss.", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.001, + help="The huber loss parameter. Only used if `--loss_type=huber`.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to 
get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision + ) + + # The scheduler calculates the alpha and sigma schedule for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + solver = DDIMSolver( + noise_scheduler.alphas_cumprod.numpy(), + timesteps=noise_scheduler.config.num_train_timesteps, + ddim_timesteps=args.num_ddim_timesteps, + ) + + # 2. Load tokenizers from SD-XL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD-XL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_teacher_model + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD-XL checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Freeze teacher vae, text_encoders, and teacher_unet + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + teacher_unet.requires_grad_(False) + + # 7. 
Create online (`unet`) student U-Net. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=[ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ], + ) + unet = get_peft_model(unet, lora_config) + + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Move teacher_unet to device, optionally cast to weight_dtype + teacher_unet.to(accelerator.device) + if args.cast_teacher_unet: + teacher_unet.to(dtype=weight_dtype) + + # Also move the alpha and sigma noise schedules to accelerator.device. + alpha_schedule = alpha_schedule.to(accelerator.device) + sigma_schedule = sigma_schedule.to(accelerator.device) + solver = solver.to(accelerator.device) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + unet_ = accelerator.unwrap_model(unet) + lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) + # save weights in peft format to be able to load them back + unet_.save_pretrained(output_dir) + + for _, model in enumerate(models): + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # load the LoRA into the model + unet_ = accelerator.unwrap_model(unet) + unet_.load_adapter(input_dir, "default", is_trainable=True) + + for _ in range(len(models)): + # pop models so that they are not loaded again + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. 
Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + # target_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation + optimizer = optimizer_class( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # 13. Dataset creation and data processing + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. 
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + # flip + x1 = image.width - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + examples["pixel_values"] = all_images + examples["captions"] = list(examples[caption_column]) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = [example["crop_top_lefts"] for example in examples] + captions = torch.stack([example["captions"] for example in examples]) + + return { + "pixel_values": pixel_values, + "captions": captions, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # 14. Embeddings for the UNet. 
+ def compute_embeddings( + prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True + ): + target_size = (args.resolution, args.resolution) + original_sizes = list(map(list, zip(*original_sizes))) + crops_coords_top_left = list(map(list, zip(*crop_coords))) + + original_sizes = torch.tensor(original_sizes, dtype=torch.long) + crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=0, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + + # 14. LR Scheduler creation + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # 15. Prepare for training + # Prepare everything with our `accelerator`. + unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros(args.train_batch_size, MAX_SEQ_LENGTH, EMBEDDING_DIM).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, POOLED_PROJECTION_DIM).to(accelerator.device) + + # 16. 
Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + image, text, orig_size, crop_coords = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = image + + # encode pixel values with batch size of at most 8 + latents = [] + for i in range(0, pixel_values.shape[0], 8): + latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps + index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # Get boundary scalings for start_timesteps and (end) timesteps. 
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] + c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) + + # Sample a random guidance scale w from U[w_min, w_max] and embed it + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + w = w.to(device=latents.device, dtype=latents.dtype) + + # Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + + # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + noise_pred = unet( + noisy_model_input, + start_timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + + pred_x_0 = predicted_origin( + noise_pred, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 + + # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + with torch.autocast("cuda"): + cond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions = copy.deepcopy(encoded_text) + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = teacher_unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # Get target LCM prediction on x_prev, w, c, t_n + with torch.no_grad(): + with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + target_noise_pred = unet( + x_prev.float(), + timesteps, + timestep_cond=None, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = c_skip * x_prev + c_out * pred_x_0 + + 
# Calculate loss + if args.loss_type == "l2": + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + elif args.loss_type == "huber": + loss = torch.mean( + torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c + ) + + # Backpropagate on the online student model (`unet`) + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + log_validation(vae, unet, args, accelerator, weight_dtype, global_step) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(args.output_dir) + lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 9ebe34555310..e62d03c730b1 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -272,10 +272,7 @@ def main(): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. 
dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - data_dir=args.train_data_dir + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {} diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 1a6ef0c856db..b69940603128 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -765,10 +765,7 @@ def load_model_hook(models, input_dir): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - data_dir=args.train_data_dir + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {} From 88efd154bd1b22efb233718af09f4629db6986ee Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 09:24:55 +0530 Subject: [PATCH 02/95] suit up the args. --- .../train_lcm_distill_lora_sdxl.py | 37 ++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index c3b314e2ab36..b7ef27f84678 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -355,7 +355,7 @@ def parse_args(): ) # ----Image Processing---- parser.add_argument( - "--train_shards_path_or_url", + "--dataset_name", type=str, default=None, help=( @@ -364,6 +364,31 @@ def parse_args(): " or to a folder containing files that 🤗 Datasets can understand." ), ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) parser.add_argument( "--resolution", type=int, @@ -373,12 +398,6 @@ def parse_args(): " resolution" ), ) - parser.add_argument( - "--use_fix_crop_and_size", - action="store_true", - help="Whether or not to use the fixed crop and size for the teacher model.", - default=False, - ) parser.add_argument( "--center_crop", default=False, @@ -466,7 +485,7 @@ def parse_args(): parser.add_argument( "--proportion_empty_prompts", type=float, - default=0, + default=0.0, help="Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).", ) # ----Latent Consistency Distillation (LCD) Specific Arguments---- @@ -1021,7 +1040,7 @@ def compute_embeddings( compute_embeddings_fn = functools.partial( compute_embeddings, - proportion_empty_prompts=0, + proportion_empty_prompts=args.proportion_empty_prompts, text_encoders=text_encoders, tokenizers=tokenizers, ) From 9e49fd2eaf14d6cef15f5e2c2c94a5cfbf2654e8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 09:59:28 +0530 Subject: [PATCH 03/95] remove comments. --- examples/consistency_distillation/train_lcm_distill_lora_sdxl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index b7ef27f84678..8274befbb053 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -862,7 +862,6 @@ def load_model_hook(models, input_dir): ) unet.enable_xformers_memory_efficient_attention() teacher_unet.enable_xformers_memory_efficient_attention() - # target_unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") From 728aa8a614289a42418a2cb8e1f62983e2d8b193 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 10:11:57 +0530 Subject: [PATCH 04/95] fix num_update_steps --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 8274befbb053..bcc0d09fe88e 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1047,7 +1047,7 @@ def compute_embeddings( # 14. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True From bc8cfddf787d5ff5b01c8de4bd4a79c63f333570 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 10:20:04 +0530 Subject: [PATCH 05/95] fix batch unmarshalling --- .../train_lcm_distill_lora_sdxl.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index bcc0d09fe88e..81fa4689179d 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1066,7 +1066,9 @@ def compute_embeddings( # 15. Prepare for training # Prepare everything with our `accelerator`. - unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) @@ -1136,7 +1138,12 @@ def compute_embeddings( for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - image, text, orig_size, crop_coords = batch + image, text, orig_size, crop_coords = ( + batch["pixel_values"], + batch["captions"], + batch["original_sizes"], + batch["crop_top_lefts"], + ) image = image.to(accelerator.device, non_blocking=True) encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) From 8c4d4b6a99e68a7e23058c0b4e0be7c5b76e2804 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 10:24:06 +0530 Subject: [PATCH 06/95] fix num_update_steps_per_epoch --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 81fa4689179d..a656a3e87cfd 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1071,7 +1071,7 @@ def compute_embeddings( ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs From 6d2f740706213e5166e41cde9fbf3815a421ac3a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 10:26:40 +0530 Subject: [PATCH 07/95] fix; dataloading. --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index a656a3e87cfd..95649a731dfb 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -987,7 +987,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() original_sizes = [example["original_sizes"] for example in examples] crop_top_lefts = [example["crop_top_lefts"] for example in examples] - captions = torch.stack([example["captions"] for example in examples]) + captions = [example["captions"] for example in examples] return { "pixel_values": pixel_values, From c7f28284f9a0efaed9a1504c1f5383cf992ea482 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 10:38:28 +0530 Subject: [PATCH 08/95] fix microconditions. --- .../train_lcm_distill_lora_sdxl.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 95649a731dfb..6a0e35d12a2d 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1006,27 +1006,23 @@ def collate_fn(examples): ) # 14. Embeddings for the UNet. 
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids def compute_embeddings( prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True ): - target_size = (args.resolution, args.resolution) - original_sizes = list(map(list, zip(*original_sizes))) - crops_coords_top_left = list(map(list, zip(*crop_coords))) - - original_sizes = torch.tensor(original_sizes, dtype=torch.long) - crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + def compute_time_ids(original_size, crops_coords_top_left): + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids prompt_embeds, pooled_prompt_embeds = encode_prompt( prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train ) add_text_embeds = pooled_prompt_embeds - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - add_time_ids = list(target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) - add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) - add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)]) prompt_embeds = prompt_embeds.to(accelerator.device) add_text_embeds = add_text_embeds.to(accelerator.device) From df707545b7ad00e81d16356cb9546513c728d787 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 11:18:30 +0530 Subject: [PATCH 09/95] unconditional predictions debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 6a0e35d12a2d..692d4e9a24f3 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1232,6 +1232,8 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + for k, v in uncond_added_conditions.items(): + print(k, v.shape) uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, From dd93227ec3a5c7fd70591cd874a46a55fc10b45b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 11:36:55 +0530 Subject: [PATCH 10/95] fix batch size. 
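The micro-conditioning fix in the hunk above builds SDXL's additional time IDs per sample by concatenating the original image size, the crop top-left coordinates, and the target size into one six-element vector. A minimal sketch of that layout, assuming plain tuples for the sizes and an illustrative 1024x1024 target (the example sizes below are made up):

    import torch

    def compute_time_ids(original_size, crop_top_left, target_size=(1024, 1024)):
        # SDXL conditions the UNet on (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
        add_time_ids = list(original_size + crop_top_left + target_size)
        return torch.tensor([add_time_ids], dtype=torch.float32)

    # One row per sample; a batch is the concatenation of the per-sample rows.
    batch_time_ids = torch.cat(
        [compute_time_ids((512, 512), (0, 0)), compute_time_ids((768, 512), (128, 0))]
    )
    print(batch_time_ids.shape)  # torch.Size([2, 6])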
--- .../train_lcm_distill_lora_sdxl.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 692d4e9a24f3..f8e316f5a74f 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1079,10 +1079,6 @@ def compute_time_ids(original_size, crops_coords_top_left): tracker_config = dict(vars(args)) accelerator.init_trackers(args.tracker_project_name, config=tracker_config) - # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros(args.train_batch_size, MAX_SEQ_LENGTH, EMBEDDING_DIM).to(accelerator.device) - uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, POOLED_PROJECTION_DIM).to(accelerator.device) - # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1229,11 +1225,16 @@ def compute_time_ids(original_size, crops_coords_top_left): sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros( + cond_teacher_output.shape[0], MAX_SEQ_LENGTH, EMBEDDING_DIM + ).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros( + cond_teacher_output.shape[0], POOLED_PROJECTION_DIM + ).to(accelerator.device) uncond_added_conditions = copy.deepcopy(encoded_text) + # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds - for k, v in uncond_added_conditions.items(): - print(k, v.shape) uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, From 3d4b1da091e259a88fae7147f31cdaac45a7cdc6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 12:55:06 +0530 Subject: [PATCH 11/95] no need to use use_auth_token --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index f8e316f5a74f..93d62ad1aa02 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -248,7 +248,7 @@ def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] From 7967247113ab6f7efcbe2000ddc11f105d1d7f0a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Nov 2023 18:10:11 +0530 Subject: [PATCH 12/95] Apply suggestions from code review Co-authored-by: Suraj Patil --- .../train_lcm_distill_lora_sdxl.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 93d62ad1aa02..bb032e6b020d 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ 
b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -445,7 +445,7 @@ def parse_args(): parser.add_argument( "--learning_rate", type=float, - default=1e-4, + default=1e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -1190,7 +1190,6 @@ def compute_time_ids(original_size, crops_coords_top_left): noise_pred = unet( noisy_model_input, start_timesteps, - timestep_cond=None, encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample @@ -1209,7 +1208,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # noisy_latents with both the conditioning embedding c and unconditional embedding 0 # Get teacher model prediction on noisy_latents and conditional embedding with torch.no_grad(): - with torch.autocast("cuda"): + with torch.autocast("cuda", dtype=weight_dtype): cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1257,11 +1256,10 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad(): - with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = unet( x_prev.float(), timesteps, - timestep_cond=None, encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample From 6b2e42f2b9a2ba8820626ae4fa50b3d30121f742 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 18:14:12 +0530 Subject: [PATCH 13/95] make vae encoding batch size an arg --- .../train_lcm_distill_lora_sdxl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index bb032e6b020d..5621a8b446f3 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -412,6 +412,12 @@ def parse_args(): action="store_true", help="whether to randomly flip images horizontally", ) + parser.add_argument( + "--encode_batch_size", + type=int, + default=8, + help="Batch size to use for VAE encoding of the images for efficient processing.", + ) # ----Dataloader---- parser.add_argument( "--dataloader_num_workers", @@ -1149,8 +1155,8 @@ def compute_time_ids(original_size, crops_coords_top_left): # encode pixel values with batch size of at most 8 latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor From d7f632e658b18b1193499d3dcbe86e94c4891b6b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 18:42:17 +0530 Subject: [PATCH 14/95] final serialization in kohya --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 5621a8b446f3..a5dd2f2e37e5 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -39,6 +39,7 @@ from peft import LoraConfig, 
get_peft_model, get_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop +from safetensors.torch import save_file from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig @@ -1341,8 +1342,8 @@ def compute_time_ids(original_size, crops_coords_top_left): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet.save_pretrained(args.output_dir) - lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") - StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) + kohya_state_dict = get_module_kohya_state_dict(unet, "lora_unet", dtype=unet.dtype) + save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) if args.push_to_hub: upload_folder( From e4edb31bfc8871d70c0a77469289d34d6464dcc6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Nov 2023 18:51:49 +0530 Subject: [PATCH 15/95] style --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index a5dd2f2e37e5..42ba90837b85 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -37,9 +37,9 @@ from huggingface_hub import create_repo, upload_folder from packaging import version from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from safetensors.torch import save_file from torchvision import transforms from torchvision.transforms.functional import crop -from safetensors.torch import save_file from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig From 6aa2dd8cab1c96211f589577e3baa26b7692fc2f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 10:29:09 +0530 Subject: [PATCH 16/95] state dict rejigging --- .../train_lcm_distill_lora_sdxl.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 42ba90837b85..8941ecd88f30 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -52,7 +52,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available, convert_state_dict_to_diffusers from diffusers.utils.import_utils import is_xformers_available @@ -132,8 +132,9 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) - pipeline.load_lora_weights(lora_state_dict) + peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) + pipeline.load_lora_weights(diffusers_state_dict) pipeline.fuse_lora() if args.enable_xformers_memory_efficient_attention: @@ -836,8 +837,6 @@ def main(args): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: unet_ = accelerator.unwrap_model(unet) - lora_state_dict 
= get_peft_model_state_dict(unet_, adapter_name="default") - StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) # save weights in peft format to be able to load them back unet_.save_pretrained(output_dir) @@ -1341,9 +1340,12 @@ def compute_time_ids(original_size, crops_coords_top_left): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) - unet.save_pretrained(args.output_dir) - kohya_state_dict = get_module_kohya_state_dict(unet, "lora_unet", dtype=unet.dtype) - save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) + # unet.save_pretrained(args.output_dir) + # kohya_state_dict = get_module_kohya_state_dict(unet, "lora_unet", dtype=unet.dtype) + # save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) + peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) + save_file(diffusers_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) if args.push_to_hub: upload_folder( From 1fd33782820e625bf44687aa604eee68a57d63d4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 10:37:38 +0530 Subject: [PATCH 17/95] feat: no separate teacher unet. --- .../train_lcm_distill_lora_sdxl.py | 111 +++++++++--------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 8941ecd88f30..8b636053bd4a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -52,7 +52,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, is_wandb_available, convert_state_dict_to_diffusers +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -754,15 +754,15 @@ def main(args): ) # 5. Load teacher U-Net from SD-XL checkpoint - teacher_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision - ) + # teacher_unet = UNet2DConditionModel.from_pretrained( + # args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + # ) # 6. Freeze teacher vae, text_encoders, and teacher_unet vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) - teacher_unet.requires_grad_(False) + # teacher_unet.requires_grad_(False) # 7. Create online (`unet`) student U-Net. unet = UNet2DConditionModel.from_pretrained( @@ -821,9 +821,9 @@ def main(args): text_encoder_two.to(accelerator.device, dtype=weight_dtype) # Move teacher_unet to device, optionally cast to weight_dtype - teacher_unet.to(accelerator.device) - if args.cast_teacher_unet: - teacher_unet.to(dtype=weight_dtype) + # teacher_unet.to(accelerator.device) + # if args.cast_teacher_unet: + # teacher_unet.to(dtype=weight_dtype) # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) @@ -867,7 +867,7 @@ def load_model_hook(models, input_dir): "xFormers 0.0.16 cannot be used for training in some GPUs. 
If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() - teacher_unet.enable_xformers_memory_efficient_attention() + # teacher_unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") @@ -1213,52 +1213,53 @@ def compute_time_ids(original_size, crops_coords_top_left): # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after # noisy_latents with both the conditioning embedding c and unconditional embedding 0 # Get teacher model prediction on noisy_latents and conditional embedding - with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): - cond_teacher_output = teacher_unet( - noisy_model_input.to(weight_dtype), - start_timesteps, - encoder_hidden_states=prompt_embeds.to(weight_dtype), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, - ).sample - cond_pred_x0 = predicted_origin( - cond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) - - # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros( - cond_teacher_output.shape[0], MAX_SEQ_LENGTH, EMBEDDING_DIM - ).to(accelerator.device) - uncond_pooled_prompt_embeds = torch.zeros( - cond_teacher_output.shape[0], POOLED_PROJECTION_DIM - ).to(accelerator.device) - uncond_added_conditions = copy.deepcopy(encoded_text) - # Get teacher model prediction on noisy_latents and unconditional embedding - uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds - uncond_teacher_output = teacher_unet( - noisy_model_input.to(weight_dtype), - start_timesteps, - encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, - ).sample - uncond_pred_x0 = predicted_origin( - uncond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) - - # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) - pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) - x_prev = solver.ddim_step(pred_x0, pred_noise, index) + # Notice that we're disabling the adapter layers within the `unet` and then it becomes a + # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
+ with torch.no_grad() and torch.autocast("cuda", dtype=weight_dtype) and unet.disable_adapter(): + cond_teacher_output = unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros(cond_teacher_output.shape[0], MAX_SEQ_LENGTH, EMBEDDING_DIM).to( + accelerator.device + ) + uncond_pooled_prompt_embeds = torch.zeros(cond_teacher_output.shape[0], POOLED_PROJECTION_DIM).to( + accelerator.device + ) + uncond_added_conditions = copy.deepcopy(encoded_text) + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad(): From 41354149079716bd5afb1760085d1cc7729cbea0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 10:53:45 +0530 Subject: [PATCH 18/95] debug --- src/diffusers/utils/state_dict_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 777c611f7150..38821d45d0dc 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -210,6 +210,7 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): raise ValueError(f"Original type {original_type} is not supported") + print(f"*******From state_dict_utils.py: {original_type}*******") mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) From 3b066d26576d2259ccfc309aad9ad2c5ef34cdc6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:04:02 +0530 Subject: [PATCH 19/95] fix state dict serialization --- .../train_lcm_distill_lora_sdxl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 8b636053bd4a..76bea9be7464 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -37,7 +37,6 @@ from huggingface_hub import create_repo, upload_folder from packaging import version from peft import LoraConfig, get_peft_model, get_peft_model_state_dict -from safetensors.torch import save_file from torchvision import transforms from 
torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -134,6 +133,10 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) + diffusers_state_dict = { + f"{pipeline.unet_name}.{module_name}": param for module_name, param in diffusers_state_dict.items() + } + pipeline.load_lora_weights(diffusers_state_dict) pipeline.fuse_lora() @@ -1346,7 +1349,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) - save_file(diffusers_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) + StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict) if args.push_to_hub: upload_folder( From fc5546fecc0bbcc89869a47b8b3c63e8d2c135b3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:12:15 +0530 Subject: [PATCH 20/95] debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- src/diffusers/utils/state_dict_utils.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 76bea9be7464..020c17829023 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -136,7 +136,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): diffusers_state_dict = { f"{pipeline.unet_name}.{module_name}": param for module_name, param in diffusers_state_dict.items() } - + print(list(diffusers_state_dict.keys())) pipeline.load_lora_weights(diffusers_state_dict) pipeline.fuse_lora() diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 38821d45d0dc..777c611f7150 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -210,7 +210,6 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): raise ValueError(f"Original type {original_type} is not supported") - print(f"*******From state_dict_utils.py: {original_type}*******") mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) From ba0d0f25d4ae156920439f9c8d95e899672e3f0c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:19:46 +0530 Subject: [PATCH 21/95] debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- src/diffusers/utils/state_dict_utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 020c17829023..388500191557 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -134,7 +134,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) 
diffusers_state_dict = { - f"{pipeline.unet_name}.{module_name}": param for module_name, param in diffusers_state_dict.items() + f"{module_name.replace('base_model.model', pipeline.unet_name)}.{module_name}": param for module_name, param in diffusers_state_dict.items() } print(list(diffusers_state_dict.keys())) pipeline.load_lora_weights(diffusers_state_dict) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 777c611f7150..2f09395fb195 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -137,6 +137,7 @@ def convert_state_dict(state_dict, mapping): k = k.replace(pattern, new_pattern) break converted_state_dict[k] = v + print("From state_dict utils", any("lora_A" in converted_state_dict)) return converted_state_dict From 35e30fbb575ab928edf9489d532da2a1c86ee742 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:21:17 +0530 Subject: [PATCH 22/95] debug --- src/diffusers/utils/state_dict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 2f09395fb195..5f6d1df5dd57 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -137,7 +137,7 @@ def convert_state_dict(state_dict, mapping): k = k.replace(pattern, new_pattern) break converted_state_dict[k] = v - print("From state_dict utils", any("lora_A" in converted_state_dict)) + print("From state_dict utils", any("lora_A" in k for k in converted_state_dict)) return converted_state_dict From 53c13f7f8dd11e7326a93ec37583f076e1edd927 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:23:33 +0530 Subject: [PATCH 23/95] remove prints. 
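The last few hunks are all about key naming: PEFT serializes LoRA parameters under a base_model.model. prefix with lora_A/lora_B names, while the diffusers loader expects entries under the pipeline's unet. prefix. A small sketch of the intended remapping on a toy state dict (the layer path below is made up for illustration):

    import torch

    peft_state_dict = {
        "base_model.model.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),
        "base_model.model.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),
    }

    # Swap the PEFT wrapper prefix for the prefix the pipeline's LoRA loader expects.
    unet_state_dict = {
        key.replace("base_model.model.", "unet."): value
        for key, value in peft_state_dict.items()
    }

    print(sorted(unet_state_dict))
    # ['unet.down_blocks.0.attentions.0.to_q.lora_A.weight',
    #  'unet.down_blocks.0.attentions.0.to_q.lora_B.weight']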
--- examples/consistency_distillation/train_lcm_distill_lora_sdxl.py | 1 - src/diffusers/utils/state_dict_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 388500191557..7c59a5c94d22 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -136,7 +136,6 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): diffusers_state_dict = { f"{module_name.replace('base_model.model', pipeline.unet_name)}.{module_name}": param for module_name, param in diffusers_state_dict.items() } - print(list(diffusers_state_dict.keys())) pipeline.load_lora_weights(diffusers_state_dict) pipeline.fuse_lora() diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 5f6d1df5dd57..777c611f7150 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -137,7 +137,6 @@ def convert_state_dict(state_dict, mapping): k = k.replace(pattern, new_pattern) break converted_state_dict[k] = v - print("From state_dict utils", any("lora_A" in k for k in converted_state_dict)) return converted_state_dict From cff23edb234232544154e3aadb68455a861418cf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:33:36 +0530 Subject: [PATCH 24/95] remove kohya utility and make style --- .../train_lcm_distill_lora_sdxl.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 7c59a5c94d22..c915864b3ba7 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -100,23 +100,6 @@ def ddim_step(self, pred_x0, pred_noise, timestep_index): return x_prev -def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): - kohya_ss_state_dict = {} - for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): - kohya_key = peft_key.replace("base_model.model", prefix) - kohya_key = kohya_key.replace("lora_A", "lora_down") - kohya_key = kohya_key.replace("lora_B", "lora_up") - kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) - kohya_ss_state_dict[kohya_key] = weight.to(dtype) - - # Set alpha parameter - if "lora_down" in kohya_key: - alpha_key = f'{kohya_key.split(".")[0]}.alpha' - kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) - - return kohya_ss_state_dict - - def log_validation(vae, unet, args, accelerator, weight_dtype, step): logger.info("Running validation... 
") @@ -134,7 +117,8 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) diffusers_state_dict = { - f"{module_name.replace('base_model.model', pipeline.unet_name)}.{module_name}": param for module_name, param in diffusers_state_dict.items() + f"{module_name.replace('base_model.model', pipeline.unet_name)}.{module_name}": param + for module_name, param in diffusers_state_dict.items() } pipeline.load_lora_weights(diffusers_state_dict) pipeline.fuse_lora() @@ -1343,9 +1327,6 @@ def compute_time_ids(original_size, crops_coords_top_left): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) - # unet.save_pretrained(args.output_dir) - # kohya_state_dict = get_module_kohya_state_dict(unet, "lora_unet", dtype=unet.dtype) - # save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")) peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict) From ca076c78f0f6f48b9de0ae13f9e5011cafdebb25 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:36:34 +0530 Subject: [PATCH 25/95] fix serialization --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index c915864b3ba7..5d2be09cd677 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1329,6 +1329,10 @@ def compute_time_ids(original_size, crops_coords_top_left): unet = accelerator.unwrap_model(unet) peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) + diffusers_state_dict = { + f"{module_name.replace('base_model.model', '')}.{module_name}": param + for module_name, param in diffusers_state_dict.items() + } StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict) if args.push_to_hub: From 808f61ead9e6a72a7f0db85d3a4b306f23636361 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 11:43:00 +0530 Subject: [PATCH 26/95] fix --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 5d2be09cd677..2daf7212c70c 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1330,7 +1330,7 @@ def compute_time_ids(original_size, crops_coords_top_left): peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) diffusers_state_dict = { - f"{module_name.replace('base_model.model', '')}.{module_name}": param + f"{module_name.replace('base_model.model.', '')}.{module_name}": param for module_name, param in diffusers_state_dict.items() } StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict) From 
842df25c101e5e8c237cc1b1996bc94ca3bfcb2f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:23:18 +0530 Subject: [PATCH 27/95] add test --- examples/test_examples.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/test_examples.py b/examples/test_examples.py index 89e866231e89..5c7305a45b29 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1680,3 +1680,29 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self): k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys ) self.assertTrue(starts_with_unet) + + def test_text_to_image_lcm__lora_sdxl(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/consistency_distillation/train_lcm_distill_lora_sdxl.py + --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) From 002767361cad9dc61ef142a6b347cf70b0b77647 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:28:15 +0530 Subject: [PATCH 28/95] add peft dependency. 
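The example test added in the previous patch smoke-checks the serialized adapter by loading the safetensors file and asserting that every key belongs to a LoRA layer. A standalone version of that check, with the file path as a placeholder:

    import safetensors.torch

    # Placeholder path; point this at the training run's output_dir.
    lora_path = "output/pytorch_lora_weights.safetensors"

    state_dict = safetensors.torch.load_file(lora_path)

    # Every serialized tensor should come from a LoRA layer.
    assert all("lora" in key for key in state_dict)
    print(f"{len(state_dict)} LoRA tensors with the expected naming")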
--- docker/diffusers-pytorch-cpu/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile index 127c61a719c5..19ddc2ec5feb 100644 --- a/docker/diffusers-pytorch-cpu/Dockerfile +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -40,6 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ numpy \ scipy \ tensorboard \ - transformers + transformers \ + peft CMD ["/bin/bash"] From c625553234be11f0adc3345222389bef3bdf57f5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:32:05 +0530 Subject: [PATCH 29/95] add: peft --- .github/workflows/pr_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index f7d9dde5258d..8de334466365 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -113,7 +113,7 @@ jobs: - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + pip install peft && python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples/test_examples.py From c5317ff3b30935ce4964f403ee92850ca5054e6a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:33:36 +0530 Subject: [PATCH 30/95] remove peft --- docker/diffusers-pytorch-cpu/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile index 19ddc2ec5feb..433c8c609885 100644 --- a/docker/diffusers-pytorch-cpu/Dockerfile +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ numpy \ scipy \ tensorboard \ - transformers \ - peft + transformers CMD ["/bin/bash"] From 6a690abd24fef16d7f67ed78734d7cbf61b4ad44 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:44:10 +0530 Subject: [PATCH 31/95] autocast device determination from accelerator --- .../train_lcm_distill_lora_sdxl.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 2daf7212c70c..89a367ef7c30 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1201,7 +1201,9 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and conditional embedding # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
- with torch.no_grad() and torch.autocast("cuda", dtype=weight_dtype) and unet.disable_adapter(): + with torch.no_grad() and torch.autocast( + str(accelerator.device), dtype=weight_dtype + ) and unet.disable_adapter(): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1248,14 +1250,13 @@ def compute_time_ids(original_size, crops_coords_top_left): x_prev = solver.ddim_step(pred_x0, pred_noise, index) # Get target LCM prediction on x_prev, w, c, t_n - with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): - target_noise_pred = unet( - x_prev.float(), - timesteps, - encoder_hidden_states=prompt_embeds.float(), - added_cond_kwargs=encoded_text, - ).sample + with torch.no_grad() and torch.autocast(str(accelerator.device), dtype=weight_dtype): + target_noise_pred = unet( + x_prev.float(), + timesteps, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample pred_x_0 = predicted_origin( target_noise_pred, timesteps, From 8c4eaf67158224999d94a9e14ca0e6e94f2a2e53 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:49:22 +0530 Subject: [PATCH 32/95] autocast --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 89a367ef7c30..c509cb54e80d 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1229,6 +1229,8 @@ def compute_time_ids(original_size, crops_coords_top_left): uncond_added_conditions = copy.deepcopy(encoded_text) # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + for k, v in uncond_added_conditions.items(): + print("From training script:", k, v.shape) uncond_teacher_output = unet( noisy_model_input.to(weight_dtype), start_timesteps, From cece78192179ca715c99d663a92f8a9905f09126 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 13:55:24 +0530 Subject: [PATCH 33/95] reduce lora rank. 
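A note on the "with torch.no_grad() and torch.autocast(...) and unet.disable_adapter():" form that the autocast patches above keep reshaping: chaining context managers with `and` evaluates the chain to a single object, so only the last manager is actually entered, which here means no_grad and autocast would not apply around the teacher forward pass. A minimal, library-free sketch of the difference (the manager name is invented for illustration):

    from contextlib import contextmanager

    @contextmanager
    def tracked(name, log):
        # Record which managers really get entered.
        log.append(name)
        yield

    log_chained, log_nested = [], []

    # `a and b` returns `b` (context-manager objects are truthy), so only "b" is entered.
    with tracked("a", log_chained) and tracked("b", log_chained):
        pass

    # Nesting (or listing the managers with commas in one with statement) enters both.
    with tracked("a", log_nested):
        with tracked("b", log_nested):
            pass

    print(log_chained)  # ['b']
    print(log_nested)   # ['a', 'b']

Writing the three managers as a comma-separated list in a single with statement would keep all of them active.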
--- examples/test_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index 5c7305a45b29..a58598bb498e 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1681,13 +1681,14 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self): ) self.assertTrue(starts_with_unet) - def test_text_to_image_lcm__lora_sdxl(self): + def test_text_to_image_lcm_lora_sdxl(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" examples/consistency_distillation/train_lcm_distill_lora_sdxl.py --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe --dataset_name hf-internal-testing/dummy_image_text_data --resolution 64 + --lora_rank 4 --train_batch_size 1 --gradient_accumulation_steps 1 --max_train_steps 2 From beb8aa2c052d156c58bf3e73f8500c77566f2c16 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:10:46 +0530 Subject: [PATCH 34/95] remove unneeded space --- docker/diffusers-pytorch-cpu/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile index 433c8c609885..127c61a719c5 100644 --- a/docker/diffusers-pytorch-cpu/Dockerfile +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -40,6 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ numpy \ scipy \ tensorboard \ - transformers + transformers CMD ["/bin/bash"] From 33cb9d0355ef8a3656cb325ed536f402e3bec251 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Nov 2023 19:16:37 +0530 Subject: [PATCH 35/95] Apply suggestions from code review Co-authored-by: Suraj Patil --- .../train_lcm_distill_lora_sdxl.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index c509cb54e80d..9405906ebd62 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2023 The LCM team and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -550,11 +550,6 @@ def parse_args(): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument( - "--cast_teacher_unet", - action="store_true", - help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", - ) # ----Training Optimizations---- parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." @@ -739,16 +734,11 @@ def main(args): revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-XL checkpoint - # teacher_unet = UNet2DConditionModel.from_pretrained( - # args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision - # ) # 6. Freeze teacher vae, text_encoders, and teacher_unet vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) - # teacher_unet.requires_grad_(False) # 7. Create online (`unet`) student U-Net. 
unet = UNet2DConditionModel.from_pretrained( @@ -806,10 +796,6 @@ def main(args): text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) - # Move teacher_unet to device, optionally cast to weight_dtype - # teacher_unet.to(accelerator.device) - # if args.cast_teacher_unet: - # teacher_unet.to(dtype=weight_dtype) # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) @@ -853,7 +839,6 @@ def load_model_hook(models, input_dir): "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() - # teacher_unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") From 795cc9f9199279f7f308a57377f21fe192a1788c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:20:01 +0530 Subject: [PATCH 36/95] style --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 9405906ebd62..d7eff52967f3 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -734,7 +734,6 @@ def main(args): revision=args.teacher_revision, ) - # 6. Freeze teacher vae, text_encoders, and teacher_unet vae.requires_grad_(False) text_encoder_one.requires_grad_(False) @@ -796,7 +795,6 @@ def main(args): text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) - # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) From 042f3578a5482b2aa4cd201b25046488647be713 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:25:41 +0530 Subject: [PATCH 37/95] remove prompt dropout. --- .../train_lcm_distill_lora_sdxl.py | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index d7eff52967f3..79fa1350eb27 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -476,12 +476,6 @@ def parse_args(): parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") # ----Diffusion Training Arguments---- - parser.add_argument( - "--proportion_empty_prompts", - type=float, - default=0.0, - help="Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).", - ) # ----Latent Consistency Distillation (LCD) Specific Arguments---- parser.add_argument( "--w_min", @@ -593,21 +587,16 @@ def parse_args(): if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: - raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") - return args # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): +def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True): prompt_embeds_list = [] captions = [] for caption in prompt_batch: - if random.random() < proportion_empty_prompts: - captions.append("") - elif isinstance(caption, str): + if isinstance(caption, str): captions.append(caption) elif isinstance(caption, (list, np.ndarray)): # take a random caption if there are multiple @@ -982,9 +971,7 @@ def collate_fn(examples): # 14. Embeddings for the UNet. # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - def compute_embeddings( - prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True - ): + def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True): def compute_time_ids(original_size, crops_coords_top_left): target_size = (args.resolution, args.resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -992,9 +979,7 @@ def compute_time_ids(original_size, crops_coords_top_left): add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) return add_time_ids - prompt_embeds, pooled_prompt_embeds = encode_prompt( - prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train - ) + prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train) add_text_embeds = pooled_prompt_embeds add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)]) @@ -1008,12 +993,7 @@ def compute_time_ids(original_size, crops_coords_top_left): text_encoders = [text_encoder_one, text_encoder_two] tokenizers = [tokenizer_one, tokenizer_two] - compute_embeddings_fn = functools.partial( - compute_embeddings, - proportion_empty_prompts=args.proportion_empty_prompts, - text_encoders=text_encoders, - tokenizers=tokenizers, - ) + compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers) # 14. LR Scheduler creation # Scheduler and math around the number of training steps. From 283af65138568197375588c048e269054d44c8db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:33:39 +0530 Subject: [PATCH 38/95] also save in native diffusers ckpt format. 
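Saving in the native diffusers layout means the distilled adapter can be pulled straight back into an SDXL pipeline with the standard LoRA loader. A usage sketch, treating the base model id, the output path, and the prompt as placeholders:

    import torch
    from diffusers import LCMScheduler, StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Swap in the LCM scheduler and the LoRA produced by this training script.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights("path/to/output_dir")  # placeholder path
    pipe.fuse_lora()

    # LCM-LoRA checkpoints are meant to be sampled with very few steps and low guidance.
    image = pipe("a pokemon with blue eyes", num_inference_steps=4, guidance_scale=1.0).images[0]
    image.save("lcm_lora_sample.png")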
--- .../train_lcm_distill_lora_sdxl.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 79fa1350eb27..e82de375ba6e 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -798,6 +798,15 @@ def save_model_hook(models, weights, output_dir): unet_ = accelerator.unwrap_model(unet) # save weights in peft format to be able to load them back unet_.save_pretrained(output_dir) + # also save the checkpoints in native `diffusers` format so that it can be easily + # be independently loaded via `load_lora_weights()`. + peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) + diffusers_state_dict = { + f"{module_name.replace('base_model.model.', '')}.{module_name}": param + for module_name, param in diffusers_state_dict.items() + } + StableDiffusionXLPipeline.save_lora_weights(output_dir, diffusers_state_dict) for _, model in enumerate(models): # make sure to pop weight so that corresponding model is not saved again From 5e099a248deae387e11cf739484f8c8672ce457d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:39:27 +0530 Subject: [PATCH 39/95] debug --- src/diffusers/models/embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a377ae267411..0d01b8946e4e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -223,6 +223,7 @@ def __init__( def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) + print(f"From embeddings: {sample.shape}, {self.linear_1.weight.shape}, {self.linear_1.bias.shape}") sample = self.linear_1(sample) if self.act is not None: From 71db43a2437d37746ea8292aca9ee3901aab88e9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 19:53:46 +0530 Subject: [PATCH 40/95] debug --- src/diffusers/models/unet_2d_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f248b243f376..f5cb221a4bec 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -978,6 +978,7 @@ def forward( time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) + print(f"From UNet: {add_embeds.shape}") aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style From e1346d5648f254be5962029a29962bd151e4dd18 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 20:11:59 +0530 Subject: [PATCH 41/95] debug --- examples/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index a58598bb498e..c92aac6606aa 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1643,7 +1643,7 @@ def test_text_to_image_lora_sdxl(self): # make sure the state_dict has the correct naming in the parameters. 
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora" in k for k in lora_state_dict.keys()) + is_lora = all("lora_sayak" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) def test_text_to_image_lora_sdxl_with_text_encoder(self): From dfcf2340f3b39aebf37b6e3e407d7762ca947b3e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 20:45:49 +0530 Subject: [PATCH 42/95] better formation of the null embeddings. --- .../train_lcm_distill_lora_sdxl.py | 14 ++------------ src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/unet_2d_condition.py | 1 - 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index e82de375ba6e..58b162bc9300 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -63,10 +63,6 @@ logger = get_logger(__name__) -MAX_SEQ_LENGTH = 77 -EMBEDDING_DIM = 2048 -POOLED_PROJECTION_DIM = 1280 - DATASET_NAME_MAPPING = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -1192,17 +1188,11 @@ def compute_time_ids(original_size, crops_coords_top_left): ) # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros(cond_teacher_output.shape[0], MAX_SEQ_LENGTH, EMBEDDING_DIM).to( - accelerator.device - ) - uncond_pooled_prompt_embeds = torch.zeros(cond_teacher_output.shape[0], POOLED_PROJECTION_DIM).to( - accelerator.device - ) + uncond_prompt_embeds = torch.zeros_like(prompt_embeds) + uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) uncond_added_conditions = copy.deepcopy(encoded_text) # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds - for k, v in uncond_added_conditions.items(): - print("From training script:", k, v.shape) uncond_teacher_output = unet( noisy_model_input.to(weight_dtype), start_timesteps, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0d01b8946e4e..e918fd628287 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -223,7 +223,7 @@ def __init__( def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) - print(f"From embeddings: {sample.shape}, {self.linear_1.weight.shape}, {self.linear_1.bias.shape}") + sample = self.linear_1(sample) if self.act is not None: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f5cb221a4bec..f248b243f376 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -978,7 +978,6 @@ def forward( time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) - print(f"From UNet: {add_embeds.shape}") aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style From 5ce6cc19eaa1559e0f20a7b0d8c563ae3915cc33 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 20:46:28 +0530 Subject: [PATCH 43/95] remove space. 
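The null-embedding change above sizes the unconditional inputs with torch.zeros_like, so the classifier-free-guidance step always matches whatever batch actually came through instead of a fixed train_batch_size. A shape-only sketch of that CFG combination on the teacher's x0 predictions (all tensor sizes below are arbitrary stand-ins):

    import torch

    batch, seq_len, dim = 2, 77, 2048
    prompt_embeds = torch.randn(batch, seq_len, dim)
    pooled_embeds = torch.randn(batch, 1280)

    # Null conditioning mirrors the real conditioning shapes, whatever the batch size is.
    uncond_prompt_embeds = torch.zeros_like(prompt_embeds)
    uncond_pooled_embeds = torch.zeros_like(pooled_embeds)

    # Stand-ins for the teacher's conditional and unconditional x0 estimates.
    cond_pred_x0 = torch.randn(batch, 4, 64, 64)
    uncond_pred_x0 = torch.randn(batch, 4, 64, 64)
    w = 7.5  # fixed here for illustration; the script samples w per batch between w_min and w_max

    # LCM-style CFG: push the conditional estimate away from the unconditional one.
    pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
    print(pred_x0.shape)  # torch.Size([2, 4, 64, 64])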
--- src/diffusers/models/embeddings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e918fd628287..a377ae267411 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -223,7 +223,6 @@ def __init__( def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) if self.act is not None: From 7ee9d5d940791faca7d20e9233dd9e329ab0980f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 20:57:57 +0530 Subject: [PATCH 44/95] autocast fixes. --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 58b162bc9300..3f67d23e5292 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1170,7 +1170,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype + str(accelerator.device), dtype=weight_dtype, enabled="cuda" in str(accelerator.device) ) and unet.disable_adapter(): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), @@ -1214,7 +1214,9 @@ def compute_time_ids(original_size, crops_coords_top_left): x_prev = solver.ddim_step(pred_x0, pred_noise, index) # Get target LCM prediction on x_prev, w, c, t_n - with torch.no_grad() and torch.autocast(str(accelerator.device), dtype=weight_dtype): + with torch.no_grad() and torch.autocast( + str(accelerator.device), dtype=weight_dtype, enabled="cuda" in str(accelerator.device) + ): target_noise_pred = unet( x_prev.float(), timesteps, From 1b359ae8198a246a35f43fc8c10f459b8b4ecdd1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 21:02:27 +0530 Subject: [PATCH 45/95] autocast fix. --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 3f67d23e5292..964cb6347dd2 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1169,8 +1169,9 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and conditional embedding # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
+ using_cuda = "cuda" in str(accelerator.device) with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype, enabled="cuda" in str(accelerator.device) + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.float32, enabled=using_cuda ) and unet.disable_adapter(): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), @@ -1215,7 +1216,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype, enabled="cuda" in str(accelerator.device) + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.float32, enabled=using_cuda ): target_noise_pred = unet( x_prev.float(), From 82b628a3ee92d26c89113fbc90600698b5fab9e6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Nov 2023 21:04:20 +0530 Subject: [PATCH 46/95] hacky --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 964cb6347dd2..f25efd84ea53 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1171,7 +1171,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # regular teacher. This way, we don't have to separately initialize a teacher UNet. using_cuda = "cuda" in str(accelerator.device) with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.float32, enabled=using_cuda + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda ) and unet.disable_adapter(): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), @@ -1216,7 +1216,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.float32, enabled=using_cuda + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda ): target_noise_pred = unet( x_prev.float(), From 17d5c0ddfb26f90e91323c5c74d450a3c0e47b34 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 08:45:08 +0530 Subject: [PATCH 47/95] remove lora_sayak --- examples/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index c92aac6606aa..a58598bb498e 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1643,7 +1643,7 @@ def test_text_to_image_lora_sdxl(self): # make sure the state_dict has the correct naming in the parameters. 
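The autocast churn in the last few patches boils down to one constraint: fp16 autocast is only meaningful on CUDA, while CPU autocast historically supported bfloat16 only, so the context has to be made device-aware (or skipped) for the CPU-based example tests to pass. A hedged sketch of one way to express that; the helper name is invented for illustration.

import contextlib

import torch


def maybe_autocast(device, weight_dtype):
    # Use autocast only on CUDA; on CPU just run in whatever dtype the tensors already have.
    if "cuda" in str(device):
        return torch.autocast("cuda", dtype=weight_dtype)
    return contextlib.nullcontext()

Separately, note that `with torch.no_grad() and torch.autocast(...)` only enters the right-hand context manager, because `and` evaluates to its second operand when the first is truthy; stacking them as `with torch.no_grad(), maybe_autocast(device, dtype):` enters both.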
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - is_lora = all("lora_sayak" in k for k in lora_state_dict.keys()) + is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) def test_text_to_image_lora_sdxl_with_text_encoder(self): From fea95e0fbbcd2fbc0cad75ca6e68e6a2c4541c38 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Nov 2023 08:50:28 +0530 Subject: [PATCH 48/95] Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- .../train_lcm_distill_lora_sdxl.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index f25efd84ea53..aa72519474e2 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -761,7 +761,7 @@ def main(args): "time_emb_proj", ], ) - unet = get_peft_model(unet, lora_config) + unet.add_adapter(lora_config) # 9. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision @@ -857,7 +857,7 @@ def load_model_hook(models, input_dir): # 12. Optimizer creation optimizer = optimizer_class( - unet.parameters(), + filter(lambda p: p.requires_grad, unet.parameters()), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -1170,15 +1170,18 @@ def compute_time_ids(original_size, crops_coords_top_left): # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
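The review suggestions folded in above share one idea: after `add_adapter()` the base UNet weights stay frozen, so the optimizer should only ever receive the LoRA parameters. A minimal sketch under that assumption, where `unet` is a diffusers `UNet2DConditionModel`; the rank, learning rate and the trimmed `target_modules` list are illustrative (the script targets a longer module list).

import torch
from peft import LoraConfig


def add_lora_and_build_optimizer(unet, rank=4, learning_rate=1e-4):
    # Freeze every base weight; only the injected LoRA matrices should train.
    unet.requires_grad_(False)
    lora_config = LoraConfig(
        r=rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0", "time_emb_proj"],
    )
    unet.add_adapter(lora_config)

    # Hand the optimizer only the parameters that actually require gradients.
    params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params_to_optimize, lr=learning_rate)
    return params_to_optimize, optimizer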
using_cuda = "cuda" in str(accelerator.device) + unet.disable_adapters() with torch.no_grad() and torch.autocast( str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda - ) and unet.disable_adapter(): + ): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample + # re-enable unet adapters + unet.enable_adapters() cond_pred_x0 = predicted_origin( cond_teacher_output, start_timesteps, From 83801a6943c3bf0126e91270596241c96c6720b3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 08:51:06 +0530 Subject: [PATCH 49/95] style --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index aa72519474e2..614c2daed98d 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -36,7 +36,7 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from peft import LoraConfig, get_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm From 0c5d9348d1dcb514b45abcd77b6beb84ae13d505 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 09:13:54 +0530 Subject: [PATCH 50/95] make log validation leaner. --- .../train_lcm_distill_lora_sdxl.py | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 614c2daed98d..64146c8f38ac 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -51,7 +51,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -96,27 +96,29 @@ def ddim_step(self, pred_x0, pred_noise, timestep_index): return x_prev -def log_validation(vae, unet, args, accelerator, weight_dtype, step): +def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_final_validation=False): logger.info("Running validation... 
") - unet = accelerator.unwrap_model(unet) pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_teacher_model, vae=vae, scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"), revision=args.revision, torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(accelerator.device) + ).to(accelerator.device) pipeline.set_progress_bar_config(disable=True) - peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") - diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) - diffusers_state_dict = { - f"{module_name.replace('base_model.model', pipeline.unet_name)}.{module_name}": param - for module_name, param in diffusers_state_dict.items() - } - pipeline.load_lora_weights(diffusers_state_dict) + to_load = None + if not is_final_validation: + if unet is None: + raise ValueError("Must provide a `unet` when doing intermediate validation.") + unet = accelerator.unwrap_model(unet) + state_dict = get_peft_model_state_dict(unet) + to_load = state_dict + else: + to_load = args.output_dir + + pipeline.load_lora_weights(to_load) pipeline.fuse_lora() if args.enable_xformers_memory_efficient_attention: @@ -169,8 +171,8 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): for image in images: image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) - - tracker.log({"validation": formatted_images}) + logger_name = "test" if is_final_validation else "validation" + tracker.log({logger_name: formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") @@ -792,17 +794,10 @@ def main(args): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: unet_ = accelerator.unwrap_model(unet) - # save weights in peft format to be able to load them back - unet_.save_pretrained(output_dir) # also save the checkpoints in native `diffusers` format so that it can be easily # be independently loaded via `load_lora_weights()`. - peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") - diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) - diffusers_state_dict = { - f"{module_name.replace('base_model.model.', '')}.{module_name}": param - for module_name, param in diffusers_state_dict.items() - } - StableDiffusionXLPipeline.save_lora_weights(output_dir, diffusers_state_dict) + state_dict = get_peft_model_state_dict(unet_) + StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict) for _, model in enumerate(models): # make sure to pop weight so that corresponding model is not saved again @@ -811,7 +806,8 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): # load the LoRA into the model unet_ = accelerator.unwrap_model(unet) - unet_.load_adapter(input_dir, "default", is_trainable=True) + lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir) + StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) for _ in range(len(models)): # pop models so that they are not loaded again @@ -856,8 +852,9 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # 12. 
Optimizer creation + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) optimizer = optimizer_class( - filter(lambda p: p.requires_grad, unet.parameters()), + params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -1172,7 +1169,7 @@ def compute_time_ids(original_size, crops_coords_top_left): using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16 ): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), @@ -1180,8 +1177,6 @@ def compute_time_ids(original_size, crops_coords_top_left): encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - # re-enable unet adapters - unet.enable_adapters() cond_pred_x0 = predicted_origin( cond_teacher_output, start_timesteps, @@ -1217,9 +1212,12 @@ def compute_time_ids(original_size, crops_coords_top_left): pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) x_prev = solver.ddim_step(pred_x0, pred_noise, index) + # re-enable unet adapters + unet.enable_adapters() + # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16 ): target_noise_pred = unet( x_prev.float(), @@ -1248,7 +1246,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1298,13 +1296,8 @@ def compute_time_ids(original_size, crops_coords_top_left): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) - peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default") - diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict) - diffusers_state_dict = { - f"{module_name.replace('base_model.model.', '')}.{module_name}": param - for module_name, param in diffusers_state_dict.items() - } - StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict) + unet_lora_state_dict = get_peft_model_state_dict(unet) + StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict) if args.push_to_hub: upload_folder( @@ -1314,6 +1307,13 @@ def compute_time_ids(original_size, crops_coords_top_left): ignore_patterns=["step_*", "epoch_*"], ) + del unet + torch.cuda.empty_cache() + + # Final inference. + if args.validation_steps is not None: + log_validation(vae, args, accelerator, weight_dtype, step=global_step, unet=None, is_final_validation=True) + accelerator.end_training() From 0f42185e2b7ba90e4fe194c9502fffdf65df88b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 09:34:01 +0530 Subject: [PATCH 51/95] move back enabled in. 
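The disable/enable pair used above is the core memory trick of this script: there is no separate teacher UNet. Turning the LoRA adapters off temporarily recovers the frozen SDXL teacher, and turning them back on restores the student before its own forward and backward pass. A minimal sketch of that control flow, with the conditioning dictionaries assumed to be prepared as in the script and the DDIM solver step elided; the helper name is illustrative.

import torch


def teacher_predictions(unet, noisy_model_input, start_timesteps, cond_kwargs, uncond_kwargs):
    # With the adapters disabled, the UNet behaves exactly like the pretrained teacher.
    unet.disable_adapters()
    with torch.no_grad():
        cond_teacher_output = unet(noisy_model_input, start_timesteps, **cond_kwargs).sample
        uncond_teacher_output = unet(noisy_model_input, start_timesteps, **uncond_kwargs).sample
    # Re-enable the LoRA layers so the student is trainable again.
    unet.enable_adapters()
    return cond_teacher_output, uncond_teacher_output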
--- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 64146c8f38ac..7402b85eab1c 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1169,7 +1169,7 @@ def compute_time_ids(original_size, crops_coords_top_left): using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16 + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda ): cond_teacher_output = unet( noisy_model_input.to(weight_dtype), @@ -1217,7 +1217,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16 + str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda ): target_noise_pred = unet( x_prev.float(), From 41f192581b1437c777102cf12bf3ef0b3f616e48 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 09:35:34 +0530 Subject: [PATCH 52/95] fix: log_validation call. --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 7402b85eab1c..3b6f42e41cf4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1283,7 +1283,9 @@ def compute_time_ids(original_size, crops_coords_top_left): logger.info(f"Saved state to {save_path}") if global_step % args.validation_steps == 0: - log_validation(vae, unet, args, accelerator, weight_dtype, global_step) + log_validation( + vae, args, accelerator, weight_dtype, global_step, unet=unet, is_final_validation=False + ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From bf5c5d6b2ab65c19fb8ff26a73723e3d05a7b8fc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 16 Nov 2023 09:42:30 +0530 Subject: [PATCH 53/95] add: checkpointing tests --- examples/test_examples.py | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/examples/test_examples.py b/examples/test_examples.py index a58598bb498e..33d9492ae65d 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -1707,3 +1707,54 @@ def test_text_to_image_lcm_lora_sdxl(self): lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) is_lora = all("lora" in k for k in lora_state_dict.keys()) self.assertTrue(is_lora) + + def test_text_to_image_lcm_lora_sdxl_checkpointing(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/consistency_distillation/train_lcm_distill_lora_sdxl.py + --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --lora_rank 4 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --checkpointing_steps 2 + 
--learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6"}, + ) + + test_args = f""" + examples/consistency_distillation/train_lcm_distill_lora_sdxl.py + --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --lora_rank 4 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 9 + --checkpointing_steps 2 + --resume_from_checkpoint latest + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}, + ) From 5534b0c255a3b2520121a35bad1e133ab7500abd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 27 Nov 2023 16:41:56 +0530 Subject: [PATCH 54/95] taking my chances to see if disabling autocasting has any effect? --- .../train_lcm_distill_lora_sdxl.py | 122 +++++++++--------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 3b6f42e41cf4..447e9c0c063a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1166,74 +1166,76 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and conditional embedding # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
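For context on the second half of the new test: `--resume_from_checkpoint latest` is conventionally resolved by sorting the `checkpoint-<step>` directories numerically and restarting from the highest step, which is why the resumed run only has to add `checkpoint-8`. A hedged sketch of that lookup; it mirrors the common pattern in the diffusers example scripts rather than quoting this script verbatim.

import os


def find_latest_checkpoint(output_dir):
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None  # nothing to resume from; training starts from scratch
    # "checkpoint-2", "checkpoint-4", ... -> order by the numeric step suffix.
    dirs.sort(key=lambda d: int(d.split("-")[1]))
    return os.path.join(output_dir, dirs[-1])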
- using_cuda = "cuda" in str(accelerator.device) + # using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() - with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda - ): - cond_teacher_output = unet( - noisy_model_input.to(weight_dtype), - start_timesteps, - encoder_hidden_states=prompt_embeds.to(weight_dtype), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, - ).sample - cond_pred_x0 = predicted_origin( - cond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) + # with torch.no_grad() and torch.autocast( + # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda + # ): + cond_teacher_output = unet( + # noisy_model_input.to(weight_dtype), + noisy_model_input, + start_timesteps, + encoder_hidden_states=prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) - # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros_like(prompt_embeds) - uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) - uncond_added_conditions = copy.deepcopy(encoded_text) - # Get teacher model prediction on noisy_latents and unconditional embedding - uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds - uncond_teacher_output = unet( - noisy_model_input.to(weight_dtype), - start_timesteps, - encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, - ).sample - uncond_pred_x0 = predicted_origin( - uncond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros_like(prompt_embeds) + uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) + uncond_added_conditions = copy.deepcopy(encoded_text) + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = unet( + noisy_model_input.to(weight_dtype), + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) - # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) - pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) - x_prev = solver.ddim_step(pred_x0, pred_noise, index) + # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) # re-enable unet adapters 
unet.enable_adapters() # Get target LCM prediction on x_prev, w, c, t_n - with torch.no_grad() and torch.autocast( - str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda - ): - target_noise_pred = unet( - x_prev.float(), - timesteps, - encoder_hidden_states=prompt_embeds.float(), - added_cond_kwargs=encoded_text, - ).sample - pred_x_0 = predicted_origin( - target_noise_pred, - timesteps, - x_prev, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) - target = c_skip * x_prev + c_out * pred_x_0 + # with torch.no_grad() and torch.autocast( + # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda + # ): + target_noise_pred = unet( + x_prev.float(), + timesteps, + # encoder_hidden_states=prompt_embeds.float(), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=encoded_text, + ).sample + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = c_skip * x_prev + c_out * pred_x_0 # Calculate loss if args.loss_type == "l2": From 1da3071f4c545ab4aca3f52ff7e5d307e5ea36a8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 11:42:07 +0530 Subject: [PATCH 55/95] start debugging --- .../train_lcm_distill_lora_sdxl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 447e9c0c063a..fdd416475266 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -852,7 +852,7 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # 12. Optimizer creation - params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -1167,7 +1167,10 @@ def compute_time_ids(original_size, crops_coords_top_left): # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
# using_cuda = "cuda" in str(accelerator.device) - unet.disable_adapters() + unet.disable_adapters() + params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) + print("Any difference in trainable params after disable:") + print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1215,6 +1218,9 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() + params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) + print("Any difference in trainable params after disable:") + print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( From bd4d1c43a88e390e166a93cac9ed2f7db56ef357 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 11:55:40 +0530 Subject: [PATCH 56/95] name --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index fdd416475266..390dffddd26f 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1171,6 +1171,8 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) + params_to_optimize_named = [n for n, p in unet.parameters() if p.requires_grad] + print(f"Optimizing parameters: {params_to_optimize_named}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): From 26f16c18c7ed31a534da8d63fe065a0868d4048f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 11:57:15 +0530 Subject: [PATCH 57/95] name --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 390dffddd26f..9b8912b01afa 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1171,7 +1171,7 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) - params_to_optimize_named = [n for n, p in unet.parameters() if p.requires_grad] + params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] print(f"Optimizing parameters: {params_to_optimize_named}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda From 917402720e633000ff7d6eb84af39c52ecc858b0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 
Nov 2023 12:01:06 +0530 Subject: [PATCH 58/95] name --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 9b8912b01afa..b8b34e3c1ace 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -860,6 +860,8 @@ def load_model_hook(models, input_dir): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) + params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] + print(f"Optimizing parameters: {params_to_optimize_named}") # 13. Dataset creation and data processing # In distributed training, the load_dataset function guarantees that only one local process can concurrently @@ -1171,8 +1173,6 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) - params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] - print(f"Optimizing parameters: {params_to_optimize_named}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): From 92ba868d57f9088243a38ef931c4eb9351aefc79 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:05:11 +0530 Subject: [PATCH 59/95] more debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index b8b34e3c1ace..912cd8eb98e2 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1224,6 +1224,9 @@ def compute_time_ids(original_size, crops_coords_top_left): print("Any difference in trainable params after disable:") print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) + params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] + print(f"Optimizing parameters: {params_to_optimize_named}") + # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda From 1fba251b12fd0c32ba28b1b7bb24a251e3f24d53 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:12:46 +0530 Subject: [PATCH 60/95] more debug --- .../train_lcm_distill_lora_sdxl.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 912cd8eb98e2..0335281e9023 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1173,6 +1173,8 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") 
print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) + params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] + print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1220,12 +1222,11 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() - params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) - print("Any difference in trainable params after disable:") - print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) - - params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] - print(f"Optimizing parameters: {params_to_optimize_named}") + params_to_optimize_after_enable = filter(lambda p: p.requires_grad, unet.parameters()) + print("Any difference in trainable params after enable:") + print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_enable)))) + params_to_optimize_named_after_enable = [n for n, p in unet.named_parameters() if p.requires_grad] + print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_enable))) # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( From 3751ca9b39bb1787083662104796d72c5dd8aa2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:15:06 +0530 Subject: [PATCH 61/95] index --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 0335281e9023..2ef96d0b5047 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1172,9 +1172,9 @@ def compute_time_ids(original_size, crops_coords_top_left): unet.disable_adapters() params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") - print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) + print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))[:5]) params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] - print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) + print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))[:5]) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): From 63649d355df2e64a8eb7fbafaecba799c59ac07c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:18:40 +0530 Subject: [PATCH 62/95] remove index. 
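One Python detail explains why the debug prints in these patches are hard to interpret: `filter(...)` returns a one-shot iterator. Once it has been consumed (by the optimizer constructor, or by a single `list(...)` call), iterating it again yields nothing, so later length checks and set differences built from the same object come out empty regardless of what `disable_adapters()`/`enable_adapters()` did. A tiny demonstration:

import torch

params = [torch.nn.Parameter(torch.randn(2)) for _ in range(3)]

trainable = filter(lambda p: p.requires_grad, params)
print(len(list(trainable)))  # 3 -- this consumes the iterator
print(len(list(trainable)))  # 0 -- a second pass over the same filter object is empty

# Materializing a list up front avoids the surprise and can be reused freely.
trainable_list = [p for p in params if p.requires_grad]
print(len(trainable_list), len(trainable_list))  # 3 3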
--- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 2ef96d0b5047..0335281e9023 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1172,9 +1172,9 @@ def compute_time_ids(original_size, crops_coords_top_left): unet.disable_adapters() params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") - print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))[:5]) + print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] - print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))[:5]) + print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): From 05de54229be5e3f5b186f461a29a05bd1caf9a56 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:28:44 +0530 Subject: [PATCH 63/95] print length --- examples/consistency_distillation/train_lcm_distill_lora_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 0335281e9023..3157be062b97 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1172,6 +1172,7 @@ def compute_time_ids(original_size, crops_coords_top_left): unet.disable_adapters() params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") + print(len(list(params_to_optimize_after_disable))) print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) From 5e604a8ad86b4e228da57e93462ef28dba465398 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:34:13 +0530 Subject: [PATCH 64/95] print length --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 3157be062b97..ab01eef48f16 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1172,8 +1172,7 @@ def compute_time_ids(original_size, crops_coords_top_left): unet.disable_adapters() params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") - print(len(list(params_to_optimize_after_disable))) - print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_disable)))) + 
print(len(list(params_to_optimize)), len(list(params_to_optimize_after_disable))) params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) # with torch.no_grad() and torch.autocast( @@ -1225,7 +1224,7 @@ def compute_time_ids(original_size, crops_coords_top_left): unet.enable_adapters() params_to_optimize_after_enable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after enable:") - print(set(list(params_to_optimize)).difference(set(list(params_to_optimize_after_enable)))) + print(len(list(params_to_optimize)), len(list(params_to_optimize_after_enable))) params_to_optimize_named_after_enable = [n for n, p in unet.named_parameters() if p.requires_grad] print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_enable))) From 8fecdda23cefc30053a511db597b9c017b8fcaa1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Nov 2023 12:40:42 +0530 Subject: [PATCH 65/95] print length --- examples/consistency_distillation/train_lcm_distill_lora_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index ab01eef48f16..493c51c23f49 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -861,6 +861,7 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] + print(f"Number of params to optimize: {len(list(params_to_optimize))}") print(f"Optimizing parameters: {params_to_optimize_named}") # 13. Dataset creation and data processing From 023866f83064605f42f6f353aafc32c345d61e5d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 1 Dec 2023 07:37:57 +0530 Subject: [PATCH 66/95] move unet.train() after add_adapter() --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 493c51c23f49..e07ed17cfad6 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -730,7 +730,6 @@ def main(args): unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) - unet.train() # Check that all trainable models are in full precision low_precision_error_string = ( @@ -764,6 +763,7 @@ def main(args): ], ) unet.add_adapter(lora_config) + unet.train() # 9. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision From 07c28de8db9e90389a14a32766df2b904032fe09 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 1 Dec 2023 07:46:24 +0530 Subject: [PATCH 67/95] disable some prints. 
--- .../train_lcm_distill_lora_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index e07ed17cfad6..05f892ed47bc 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1174,8 +1174,8 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after disable:") print(len(list(params_to_optimize)), len(list(params_to_optimize_after_disable))) - params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] - print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) + # params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] + # print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1226,8 +1226,8 @@ def compute_time_ids(original_size, crops_coords_top_left): params_to_optimize_after_enable = filter(lambda p: p.requires_grad, unet.parameters()) print("Any difference in trainable params after enable:") print(len(list(params_to_optimize)), len(list(params_to_optimize_after_enable))) - params_to_optimize_named_after_enable = [n for n, p in unet.named_parameters() if p.requires_grad] - print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_enable))) + # params_to_optimize_named_after_enable = [n for n, p in unet.named_parameters() if p.requires_grad] + # print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_enable))) # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( From c6a61dac312fde5d2f98a8f449f52f53e43d44d8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 1 Dec 2023 07:48:17 +0530 Subject: [PATCH 68/95] enable_adapters() manually. --- .../train_lcm_distill_lora_sdxl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 05f892ed47bc..ddca1ba4d271 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -762,8 +762,9 @@ def main(args): "time_emb_proj", ], ) - unet.add_adapter(lora_config) unet.train() + unet.add_adapter(lora_config) + unet.enable_adapters() # 9. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision @@ -860,9 +861,9 @@ def load_model_hook(models, input_dir): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] + # params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] print(f"Number of params to optimize: {len(list(params_to_optimize))}") - print(f"Optimizing parameters: {params_to_optimize_named}") + # print(f"Optimizing parameters: {params_to_optimize_named}") # 13. 
Dataset creation and data processing # In distributed training, the load_dataset function guarantees that only one local process can concurrently From ec33085ef351bb5bd972c19e93839007d4cf73b9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 2 Dec 2023 09:36:24 +0530 Subject: [PATCH 69/95] remove prints. --- .../train_lcm_distill_lora_sdxl.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index ddca1ba4d271..0d5ca778b916 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -861,9 +861,6 @@ def load_model_hook(models, input_dir): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - # params_to_optimize_named = [n for n, p in unet.named_parameters() if p.requires_grad] - print(f"Number of params to optimize: {len(list(params_to_optimize))}") - # print(f"Optimizing parameters: {params_to_optimize_named}") # 13. Dataset creation and data processing # In distributed training, the load_dataset function guarantees that only one local process can concurrently @@ -1172,11 +1169,6 @@ def compute_time_ids(original_size, crops_coords_top_left): # regular teacher. This way, we don't have to separately initialize a teacher UNet. # using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() - params_to_optimize_after_disable = filter(lambda p: p.requires_grad, unet.parameters()) - print("Any difference in trainable params after disable:") - print(len(list(params_to_optimize)), len(list(params_to_optimize_after_disable))) - # params_to_optimize_named_after_disable = [n for n, p in unet.named_parameters() if p.requires_grad] - # print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_disable))) # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1224,12 +1216,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() - params_to_optimize_after_enable = filter(lambda p: p.requires_grad, unet.parameters()) - print("Any difference in trainable params after enable:") - print(len(list(params_to_optimize)), len(list(params_to_optimize_after_enable))) - # params_to_optimize_named_after_enable = [n for n, p in unet.named_parameters() if p.requires_grad] - # print(set(params_to_optimize_named).difference(set(params_to_optimize_named_after_enable))) - + # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda From ed7969d2073d4629df09f3fba2984441d835e42d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 18:52:38 +0530 Subject: [PATCH 70/95] some changes. --- .../train_lcm_distill_lora_sdxl.py | 66 +++++++++---------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 0d5ca778b916..7e04bdeb139e 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -721,7 +721,7 @@ def main(args): revision=args.teacher_revision, ) - # 6. 
Freeze teacher vae, text_encoders, and teacher_unet + # 6. Freeze teacher vae, text_encoders. vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) @@ -730,6 +730,7 @@ def main(args): unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) + unet.requires_grad_(False) # Check that all trainable models are in full precision low_precision_error_string = ( @@ -742,6 +743,22 @@ def main(args): f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=torch.float32) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. lora_config = LoraConfig( r=args.lora_rank, @@ -762,26 +779,7 @@ def main(args): "time_emb_proj", ], ) - unet.train() unet.add_adapter(lora_config) - unet.enable_adapters() - - # 9. Handle mixed precision and device placement - # For mixed precision training we cast all non-trainable weigths to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move unet, vae and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. - vae.to(accelerator.device) - if args.pretrained_vae_model_name_or_path is not None: - vae.to(dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) @@ -853,7 +851,8 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # 12. 
Optimizer creation - params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters()) + params_to_optimize = [p.to(torch.float32) for p in unet.parameters() if p.requires_grad] + print(f"params_to_optimize: {params_to_optimize}") optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -1085,27 +1084,21 @@ def compute_time_ids(original_size, crops_coords_top_left): disable=not accelerator.is_local_main_process, ) + unet.train() for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - image, text, orig_size, crop_coords = ( + pixel_values, text, orig_size, crop_coords = ( batch["pixel_values"], batch["captions"], batch["original_sizes"], batch["crop_top_lefts"], ) - image = image.to(accelerator.device, non_blocking=True) encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) - if args.pretrained_vae_model_name_or_path is not None: - pixel_values = image.to(dtype=weight_dtype) - if vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) - else: - pixel_values = image - # encode pixel values with batch size of at most 8 + pixel_values = pixel_values.to(dtype=vae.dtype) latents = [] for i in range(0, pixel_values.shape[0], args.encode_batch_size): latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample()) @@ -1148,7 +1141,7 @@ def compute_time_ids(original_size, crops_coords_top_left): noise_pred = unet( noisy_model_input, start_timesteps, - encoder_hidden_states=prompt_embeds.float(), + encoder_hidden_states=prompt_embeds, added_cond_kwargs=encoded_text, ).sample @@ -1168,7 +1161,9 @@ def compute_time_ids(original_size, crops_coords_top_left): # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
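The reshuffling above follows the usual recipe for mixed-precision LoRA training: freeze the inference-only networks and cast them to the compute dtype, keep the SDXL VAE in fp32 because its fp16 activations are prone to NaNs, and attach the adapter afterwards. A condensed sketch of that setup; it assumes the models and `lora_config` are already constructed as in the script, and the helper name is illustrative.

import torch


def prepare_models(accelerator, unet, vae, text_encoder_one, text_encoder_two, lora_config):
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Everything here is frozen; gradients only flow through the LoRA layers added below.
    for model in (unet, vae, text_encoder_one, text_encoder_two):
        model.requires_grad_(False)

    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
    # The VAE stays in fp32 to avoid NaNs when encoding pixels under fp16.
    vae.to(accelerator.device, dtype=torch.float32)

    # Attach the LoRA adapter last; only its matrices will require gradients.
    unet.add_adapter(lora_config)
    return weight_dtype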
# using_cuda = "cuda" in str(accelerator.device) - unet.disable_adapters() + unet.disable_adapters() + params_to_optimize_disabled = [p for p in unet.parameters() if p.requires_grad] + print(f"params_to_optimize after disabled: {params_to_optimize_disabled}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1216,15 +1211,16 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() - + params_to_optimize_enabled = [p for p in unet.parameters() if p.requires_grad] + print(f"params_to_optimize after enabled: {params_to_optimize_enabled}") + # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): target_noise_pred = unet( - x_prev.float(), + x_prev, timesteps, - # encoder_hidden_states=prompt_embeds.float(), encoder_hidden_states=prompt_embeds, added_cond_kwargs=encoded_text, ).sample From 8c549e4b3073c4e857cb4a75969d3b84e5620dd3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 18:58:01 +0530 Subject: [PATCH 71/95] fix params_to_optimize --- .../train_lcm_distill_lora_sdxl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 7e04bdeb139e..9e3cd8f72674 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -851,8 +851,12 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # 12. 
Optimizer creation - params_to_optimize = [p.to(torch.float32) for p in unet.parameters() if p.requires_grad] - print(f"params_to_optimize: {params_to_optimize}") + params_to_optimize = [] + for param in unet.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) + params_to_optimize.append(param) + print(f"params_to_optimize: {len(params_to_optimize)}") optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, From 9446066635d9ba1958d62f563356a4b9ee5ac998 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 19:06:32 +0530 Subject: [PATCH 72/95] more fixes --- .../train_lcm_distill_lora_sdxl.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 9e3cd8f72674..685a32823f79 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1167,15 +1167,14 @@ def compute_time_ids(original_size, crops_coords_top_left): # using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() params_to_optimize_disabled = [p for p in unet.parameters() if p.requires_grad] - print(f"params_to_optimize after disabled: {params_to_optimize_disabled}") + print(f"params_to_optimize after disabled: {len(params_to_optimize_disabled)}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): cond_teacher_output = unet( - # noisy_model_input.to(weight_dtype), noisy_model_input, start_timesteps, - encoder_hidden_states=prompt_embeds.to(weight_dtype), + encoder_hidden_states=prompt_embeds, added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample cond_pred_x0 = predicted_origin( @@ -1194,7 +1193,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = unet( - noisy_model_input.to(weight_dtype), + noisy_model_input, start_timesteps, encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, @@ -1216,7 +1215,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() params_to_optimize_enabled = [p for p in unet.parameters() if p.requires_grad] - print(f"params_to_optimize after enabled: {params_to_optimize_enabled}") + print(f"params_to_optimize after enabled: {len(params_to_optimize_enabled)}") # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( @@ -1226,7 +1225,7 @@ def compute_time_ids(original_size, crops_coords_top_left): x_prev, timesteps, encoder_hidden_states=prompt_embeds, - added_cond_kwargs=encoded_text, + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample pred_x_0 = predicted_origin( target_noise_pred, From 0153665f63ab9b4c9543bf4b8f1089e0d4a5a0c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 19:10:23 +0530 Subject: [PATCH 73/95] debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py 
index 685a32823f79..772ee42d5a13 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1221,6 +1221,11 @@ def compute_time_ids(original_size, crops_coords_top_left): # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): + print(f"x_prev: {x_prev.dtype}") + print(f"timesteps: {timesteps.dtype}") + print(f"prompt_embeds: {prompt_embeds}") + for k, v in encoded_text.items(): + print(k, v.dtype) target_noise_pred = unet( x_prev, timesteps, From b9891ffbe81b918d18f7ea1e5f50b2c4127994e9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 19:13:24 +0530 Subject: [PATCH 74/95] debug --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 772ee42d5a13..84f68872abe8 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1211,6 +1211,7 @@ def compute_time_ids(original_size, crops_coords_top_left): pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) x_prev = solver.ddim_step(pred_x0, pred_noise, index) + x_prev = x_prev.to(unet.dtype) # re-enable unet adapters unet.enable_adapters() @@ -1221,11 +1222,6 @@ def compute_time_ids(original_size, crops_coords_top_left): # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): - print(f"x_prev: {x_prev.dtype}") - print(f"timesteps: {timesteps.dtype}") - print(f"prompt_embeds: {prompt_embeds}") - for k, v in encoded_text.items(): - print(k, v.dtype) target_noise_pred = unet( x_prev, timesteps, From b11b0a6d207d4f66bba7762e30a2efbba71f1e75 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 19:15:13 +0530 Subject: [PATCH 75/95] remove print --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 84f68872abe8..26ea92e23119 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -856,7 +856,6 @@ def load_model_hook(models, input_dir): if param.requires_grad: param.data = param.to(torch.float32) params_to_optimize.append(param) - print(f"params_to_optimize: {len(params_to_optimize)}") optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -1166,8 +1165,6 @@ def compute_time_ids(original_size, crops_coords_top_left): # regular teacher. This way, we don't have to separately initialize a teacher UNet. 
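Editor's note: the "fix params_to_optimize" change above is subtle enough to deserve a comment. When the LoRA weights start out in half precision, `p.to(torch.float32)` returns new tensors, so the optimizer would have stepped on copies the UNet never sees; assigning to `param.data` upcasts the live parameters in place and hands those same objects to the optimizer. A condensed before/after sketch:

```python
import torch

def collect_trainable_fp32_params(unet):
    # Buggy version: `.to()` allocates new tensors when the dtype changes,
    # so the UNet keeps using its original fp16 parameters.
    # params = [p.to(torch.float32) for p in unet.parameters() if p.requires_grad]

    # Fixed version: upcast the existing parameters in place and optimize those.
    params_to_optimize = []
    for param in unet.parameters():
        if param.requires_grad:
            param.data = param.to(torch.float32)
            params_to_optimize.append(param)
    return params_to_optimize

# optimizer = torch.optim.AdamW(collect_trainable_fp32_params(unet), lr=1e-4)
```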
# using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() - params_to_optimize_disabled = [p for p in unet.parameters() if p.requires_grad] - print(f"params_to_optimize after disabled: {len(params_to_optimize_disabled)}") # with torch.no_grad() and torch.autocast( # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda # ): @@ -1215,8 +1212,6 @@ def compute_time_ids(original_size, crops_coords_top_left): # re-enable unet adapters unet.enable_adapters() - params_to_optimize_enabled = [p for p in unet.parameters() if p.requires_grad] - print(f"params_to_optimize after enabled: {len(params_to_optimize_enabled)}") # Get target LCM prediction on x_prev, w, c, t_n # with torch.no_grad() and torch.autocast( From 539bda398308cb0b3991353f3dc8faca66f07949 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 3 Dec 2023 20:31:47 +0530 Subject: [PATCH 76/95] disable grad for certain contexts. --- .../train_lcm_distill_lora_sdxl.py | 117 +++++++++--------- 1 file changed, 56 insertions(+), 61 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 26ea92e23119..73e56293f372 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -1163,75 +1163,70 @@ def compute_time_ids(original_size, crops_coords_top_left): # Get teacher model prediction on noisy_latents and conditional embedding # Notice that we're disabling the adapter layers within the `unet` and then it becomes a # regular teacher. This way, we don't have to separately initialize a teacher UNet. - # using_cuda = "cuda" in str(accelerator.device) unet.disable_adapters() - # with torch.no_grad() and torch.autocast( - # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda - # ): - cond_teacher_output = unet( - noisy_model_input, - start_timesteps, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, - ).sample - cond_pred_x0 = predicted_origin( - cond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) + with torch.no_grad(): + cond_teacher_output = unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) - # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros_like(prompt_embeds) - uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) - uncond_added_conditions = copy.deepcopy(encoded_text) - # Get teacher model prediction on noisy_latents and unconditional embedding - uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds - uncond_teacher_output = unet( - noisy_model_input, - start_timesteps, - encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, - ).sample - uncond_pred_x0 = predicted_origin( - uncond_teacher_output, - start_timesteps, - noisy_model_input, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) + # 
Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros_like(prompt_embeds) + uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) + uncond_added_conditions = copy.deepcopy(encoded_text) + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + uncond_teacher_output = unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, + ).sample + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) - # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) - pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) - x_prev = solver.ddim_step(pred_x0, pred_noise, index) - x_prev = x_prev.to(unet.dtype) + # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + x_prev = x_prev.to(unet.dtype) # re-enable unet adapters unet.enable_adapters() # Get target LCM prediction on x_prev, w, c, t_n - # with torch.no_grad() and torch.autocast( - # str(accelerator.device), dtype=weight_dtype if using_cuda else torch.bfloat16, enabled=using_cuda - # ): - target_noise_pred = unet( - x_prev, - timesteps, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, - ).sample - pred_x_0 = predicted_origin( - target_noise_pred, - timesteps, - x_prev, - noise_scheduler.config.prediction_type, - alpha_schedule, - sigma_schedule, - ) - target = c_skip * x_prev + c_out * pred_x_0 + with torch.no_grad(): + target_noise_pred = unet( + x_prev, + timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = c_skip * x_prev + c_out * pred_x_0 # Calculate loss if args.loss_type == "l2": From d5a40cde0e0ff3a1d33a53561b3c222780de9e94 Mon Sep 17 00:00:00 2001 From: Fabio Rigano <57982783+fabiorigano@users.noreply.github.com> Date: Thu, 7 Dec 2023 17:40:39 +0100 Subject: [PATCH 77/95] Add support for IPAdapterFull (#5911) * Add support for IPAdapterFull Co-authored-by: Patrick von Platen --------- Co-authored-by: YiYi Xu Co-authored-by: Patrick von Platen --- .../en/using-diffusers/loading_adapters.md | 63 +++++++++++++++++++ src/diffusers/loaders/unet.py | 29 ++++++++- src/diffusers/models/embeddings.py | 12 ++++ .../test_ip_adapter_stable_diffusion.py | 19 ++++++ 4 files changed, 122 insertions(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index c14b38a9dd89..d9d4a675dd37 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -485,6 +485,69 @@ image.save("sdxl_t2i.png") +You can use the IP-Adapter face model to apply specific faces to 
your images. It is an effective way to maintain consistent characters in your image generations. +Weights are loaded with the same method used for the other IP-Adapters. + +```python +# Load ip-adapter-full-face_sd15.bin +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin") +``` + + + +It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model. + + + + +```python +import torch +from diffusers import StableDiffusionPipeline, DDIMScheduler +from diffusers.utils import load_image + +noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1 +) + +pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16, + scheduler=noise_scheduler, +).to("cuda") + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin") + +pipeline.set_ip_adapter_scale(0.7) + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") + +generator = torch.Generator(device="cpu").manual_seed(33) + +image = pipeline( + prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", + ip_adapter_image=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=50, num_images_per_prompt=1, width=512, height=704, + generator=generator, +).images[0] +``` + +
+<!-- figure block (HTML markup lost in extraction): side-by-side comparison with captions "input image" and "output image" -->
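Editor's note: since the tip above recommends `EulerDiscreteScheduler` as well as `DDIMScheduler` for the face model, a minimal sketch of swapping the scheduler on an already-constructed pipeline (model id mirrors the example above; reusing the existing scheduler config avoids spelling out every beta/timestep argument):

```python
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rebuild the scheduler from the pipeline's own config, then load the face adapter as before.
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
```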
### LCM-Lora diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index b155f595e740..7309c3fc709c 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -22,7 +22,7 @@ from huggingface_hub.utils import validate_hf_hub_args from torch import nn -from ..models.embeddings import ImageProjection, Resampler +from ..models.embeddings import ImageProjection, MLPProjection, Resampler from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( USE_PEFT_BACKEND, @@ -675,6 +675,9 @@ def _load_ip_adapter_weights(self, state_dict): if "proj.weight" in state_dict["image_proj"]: # IP-Adapter num_image_text_embeds = 4 + elif "proj.3.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face + num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token else: # IP-Adapter Plus num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] @@ -744,8 +747,32 @@ def _load_ip_adapter_weights(self, state_dict): "norm.bias": state_dict["image_proj"]["norm.bias"], } ) + image_projection.load_state_dict(image_proj_state_dict) + del image_proj_state_dict + elif "proj.3.weight" in state_dict["image_proj"]: + clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0] + + image_projection = MLPProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + ) + image_projection.to(dtype=self.dtype, device=self.device) + + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"], + "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"], + "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"], + "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"], + "norm.weight": state_dict["image_proj"]["proj.3.weight"], + "norm.bias": state_dict["image_proj"]["proj.3.bias"], + } + ) image_projection.load_state_dict(image_proj_state_dict) + del image_proj_state_dict else: # IP-Adapter Plus diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bdd2930d20f9..73abc9869230 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -461,6 +461,18 @@ def forward(self, image_embeds: torch.FloatTensor): return image_embeds +class MLPProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from .attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + return self.norm(self.ff(image_embeds)) + + class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 7c6349ce2600..ff93ecaf003b 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -182,6 +182,25 @@ def test_inpainting(self): assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_text_to_image_full_face(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", 
subfolder="models/image_encoder") + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin") + pipeline.set_ip_adapter_scale(0.7) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.1706543, 0.1303711, 0.12573242, 0.21777344, 0.14550781, 0.14038086, 0.40820312, 0.41455078, 0.42529297] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + @slow @require_torch_gpu From e3d76c47f993bcbdc92ef89bc6ec0ee90bcffa8b Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 7 Dec 2023 11:35:28 -1000 Subject: [PATCH 78/95] Fix a bug in `add_noise` function (#6085) * fix * copies --------- Co-authored-by: yiyixuxu --- src/diffusers/schedulers/scheduling_deis_multistep.py | 11 ++++++++++- .../schedulers/scheduling_dpmsolver_multistep.py | 11 ++++++++++- .../scheduling_dpmsolver_multistep_inverse.py | 11 ++++++++++- .../schedulers/scheduling_dpmsolver_singlestep.py | 11 ++++++++++- .../schedulers/scheduling_unipc_multistep.py | 11 ++++++++++- 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 7cf6a9b33b37..bd44d2444154 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -734,7 +734,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index beab985e3350..086505c5052b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -896,7 +896,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 61d6810ce286..cfb53c943cea 100644 --- 
a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -891,7 +891,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 0f1175472f3e..7e8149ab55c4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -897,7 +897,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 1d58ab5259ef..eaa6273e2768 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -828,7 +828,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): From 472c3974ad28dc5efeeac2a06ce95d066259e55f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Fri, 8 Dec 2023 11:45:27 +0000 Subject: [PATCH 79/95] [Advanced Diffusion Script] Add Widget default text (#6100) add widget --- .../train_dreambooth_lora_sdxl_advanced.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f4b4e42c8b19..4d4ec523bfe6 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -157,6 +157,8 @@ def save_model_card( base_model: {base_model} instance_prompt: {instance_prompt} 
license: openrail++ +widget: + - text: '{validation_prompt if validation_prompt else instance_prompt}' --- """ From 373d39239d3d782088eb0a7377377a8b8e7b0dcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Fri, 8 Dec 2023 14:56:35 +0000 Subject: [PATCH 80/95] [Advanced Training Script] Fix pipe example (#6106) --- .../train_dreambooth_lora_sdxl_advanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 4d4ec523bfe6..a46a1afcc145 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -135,8 +135,8 @@ def save_model_card( """ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model") state_dict = load_file(embedding_path) -pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) -pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) +pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) +pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) """ if token_abstraction_dict: for key, value in token_abstraction_dict.items(): From be46b6ebf7f79c9eddda37499ce68baca16ad13f Mon Sep 17 00:00:00 2001 From: Charchit Sharma Date: Sat, 9 Dec 2023 11:02:55 +0530 Subject: [PATCH 81/95] IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (#5901) * adapter for StableDiffusionControlNetImg2ImgPipeline * fix-copies * fix-copies --------- Co-authored-by: Sayak Paul --- .../controlnet/pipeline_controlnet_img2img.py | 50 ++++++++++++++++--- .../controlnet/test_controlnet_img2img.py | 2 + 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index fa489941c987..037641bd820e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -19,10 +19,10 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -130,7 +130,7 @@ def prepare_image(image): class StableDiffusionControlNetImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" Pipeline for image-to-image generation using Stable Diffusion 
with ControlNet guidance. @@ -140,7 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. @@ -166,7 +166,7 @@ class StableDiffusionControlNetImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -180,6 +180,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -212,6 +213,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -468,6 +470,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -861,6 +888,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -922,6 +950,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -1053,6 +1082,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) @@ -1111,7 +1145,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Create tensor stating which controlnets to keep + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ @@ -1171,6 +1208,7 @@ def __call__( cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 5a7f70eb488a..b4b67e6476f6 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -134,6 +134,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -273,6 +274,7 @@ def init_weights(m): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components From c7a87ca78495bf2afaef4779db3ac828fede1397 Mon Sep 17 00:00:00 2001 From: Aryan V S Date: Sun, 10 Dec 2023 21:19:14 +0530 Subject: [PATCH 82/95] IP adapter support for most pipelines (#5900) * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py * update tests * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py * support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py * support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py * support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py * support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py * revert changes to sd_attend_and_excite and sd_upscale * make style * fix broken tests * update 
ip-adapter implementation to latest * apply suggestions from review --------- Co-authored-by: YiYi Xu Co-authored-by: Sayak Paul --- .../pipeline_latent_consistency_img2img.py | 51 +++++++++++-- .../pipeline_latent_consistency_text2img.py | 53 ++++++++++++-- ...eline_stable_diffusion_instruct_pix2pix.py | 64 ++++++++++++++-- .../pipeline_stable_diffusion_ldm3d.py | 55 ++++++++++++-- .../pipeline_stable_diffusion_panorama.py | 60 +++++++++++++-- .../pipeline_stable_diffusion_sag.py | 73 +++++++++++++++++-- .../pipeline_stable_diffusion_safe.py | 61 ++++++++++++++-- .../test_latent_consistency_models.py | 1 + .../test_latent_consistency_models_img2img.py | 1 + ...st_stable_diffusion_instruction_pix2pix.py | 1 + .../test_stable_diffusion_ldm3d.py | 1 + .../test_stable_diffusion_panorama.py | 1 + .../test_stable_diffusion_sag.py | 1 + 13 files changed, 380 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index ed29a939388f..63a54f5aa666 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -20,11 +20,11 @@ import PIL.Image import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import LCMScheduler from ...utils import ( @@ -129,7 +129,7 @@ def retrieve_timesteps( class LatentConsistencyModelImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for image-to-image generation using a latent consistency model. 
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -166,7 +167,7 @@ class LatentConsistencyModelImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] @@ -179,6 +180,7 @@ def __init__( scheduler: LCMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -191,6 +193,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) if safety_checker is None and requires_safety_checker: @@ -449,6 +452,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -647,6 +675,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -695,6 +724,8 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -758,6 +789,12 @@ def __call__( device = self._execution_device # do_classifier_free_guidance = guidance_scale > 1.0 + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None @@ -815,6 +852,9 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 8. LCM Multistep Sampling Loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -829,6 +869,7 @@ def __call__( timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index c8f1d647c15b..54d5a2ec989d 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -19,11 +19,11 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import LCMScheduler from ...utils import ( @@ -107,7 +107,7 @@ def retrieve_timesteps( class LatentConsistencyModelPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using a latent consistency model. 
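Editor's note: for the text-to-image latent consistency pipeline, the practical effect of this patch is that an IP-Adapter reference image can be combined with the usual few-step LCM sampling. A hedged usage sketch; the LCM checkpoint id and the reference-image URL are placeholders, not part of this patch:

```python
import torch
from diffusers import LatentConsistencyModelPipeline
from diffusers.utils import load_image

# Checkpoint id is illustrative; any LCM-distilled SD 1.5 checkpoint should work.
pipeline = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image("https://example.com/reference.png")  # placeholder URL

image = pipeline(
    prompt="a photo of a cat, best quality",
    ip_adapter_image=ip_image,
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
```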
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] @@ -157,6 +158,7 @@ def __init__( scheduler: LCMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -185,6 +187,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -433,6 +436,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -581,6 +609,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -629,6 +658,8 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. 
+ ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -697,6 +728,12 @@ def __call__( device = self._execution_device # do_classifier_free_guidance = guidance_scale > 1.0 + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None @@ -748,6 +785,9 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 8. LCM MultiStep Sampling Loop: num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -762,6 +802,7 @@ def __call__( timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index d922803858b0..b0021c5a3e63 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -18,11 +18,11 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils.torch_utils import randn_tensor @@ -72,7 +72,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionInstructPix2PixPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin +): r""" Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). 
@@ -83,6 +85,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -105,7 +108,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] @@ -118,6 +121,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -146,6 +150,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -166,6 +171,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, @@ -213,6 +219,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -293,6 +301,16 @@ def __call__( self._guidance_scale = guidance_scale self._image_guidance_scale = image_guidance_scale + device = self._execution_device + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds]) + if image is None: raise ValueError("`image` input cannot be undefined.") @@ -367,6 +385,9 @@ def __call__( # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 8.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -383,7 +404,11 @@ def __call__( # predict the noise residual noise_pred = self.unet( - scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, )[0] # Hack: @@ -598,11 +623,36 @@ def _encode_prompt( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index f410c08a3bbe..ee9335a2bb01 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -19,11 +19,11 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessorLDM3D -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput, VaeImageProcessorLDM3D +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import 
AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -82,7 +82,7 @@ class LDM3DPipelineOutput(BaseOutput): class StableDiffusionLDM3DPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image and 3D generation using LDM3D. @@ -95,6 +95,7 @@ class StableDiffusionLDM3DPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -117,7 +118,7 @@ class StableDiffusionLDM3DPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( @@ -129,6 +130,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection], requires_safety_checker: bool = True, ): super().__init__() @@ -157,6 +159,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) @@ -410,6 +413,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -529,6 +557,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: 
Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -573,6 +602,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -622,6 +653,14 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -659,6 +698,9 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -673,6 +715,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index ff6a66ab57c9..bcc063499459 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -16,11 +16,11 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import ( @@ -59,13 +59,19 @@ """ -class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin): r""" Pipeline for text-to-image generation using MultiDiffusion. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. @@ -87,7 +93,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( @@ -99,6 +105,7 @@ def __init__( scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -127,6 +134,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -363,6 +371,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -529,6 +562,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -578,6 +612,8 @@ def __call__( 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -632,6 +668,14 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None @@ -681,6 +725,9 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + # 8. Denoising loop # Each denoising step also includes refinement of the latents with respect to the # views. @@ -743,6 +790,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds_input, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 68652e977c5d..792a3c40b33d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -17,11 +17,11 @@ import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -98,13 +98,17 @@ def __call__( # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input -class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. @@ -126,7 +130,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( @@ -138,6 +142,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -150,6 +155,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -386,6 +392,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -519,6 +550,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -565,6 +597,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -618,6 +652,14 @@ def __call__( # `sag_scale = 0` means no self-attention guidance do_self_attention_guidance = sag_scale > 0.0 + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -655,6 +697,10 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None + # 7. Denoising loop store_processor = CrossAttnStoreProcessor() self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor @@ -680,6 +726,7 @@ def get_map_size(module, input, output): t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance @@ -703,7 +750,12 @@ def get_map_size(module, input, output): ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance - degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=uncond_emb, + added_cond_kwargs=added_uncond_kwargs, + ).sample noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 @@ -715,7 +767,12 @@ def get_map_size(module, input, output): pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance - degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample noise_pred += sag_scale * (noise_pred - degraded_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index eb24cbfd947b..fdc7844a7e08 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -5,10 +5,12 @@ import numpy as np import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict -from ...models import 
AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging from ...utils.torch_utils import randn_tensor @@ -20,13 +22,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class StableDiffusionPipelineSafe(DiffusionPipeline): +class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin): r""" Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). + The pipeline also inherits the following loading methods: + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. @@ -48,7 +53,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] def __init__( self, @@ -59,6 +64,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() @@ -140,6 +146,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self._safety_text_concept = safety_concept self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -467,6 +474,31 @@ def perform_safety_guidance( noise_guidance = noise_guidance - noise_guidance_safety return noise_guidance, safety_momentum + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + @torch.no_grad() def __call__( self, @@ -480,6 +512,7 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + 
ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -521,6 +554,8 @@ def __call__( Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -588,6 +623,17 @@ def __call__( if not enable_safety_guidance: warnings.warn("Safety checker disabled!") + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + if enable_safety_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds]) + else: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance @@ -613,6 +659,9 @@ def __call__( # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + safety_momentum = None num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -627,7 +676,9 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs + ).sample # perform guidance if do_classifier_free_guidance: diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py index 174d9b6de9f8..5d33b45c0973 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py @@ -87,6 +87,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, "requires_safety_checker": False, } return components diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py index f9410ffe640a..5b4e2b191f53 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py @@ -97,6 +97,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, "requires_safety_checker": False, } return components diff --git 
a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 69b36cb3bb8a..461bd14a43d8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -108,6 +108,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py index e7e98c52d92c..cfdb0c57bbdb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py @@ -93,6 +93,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 657608df8b98..ad86a53f5ee9 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -91,6 +91,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py index 6eae1ce4d371..8b789408f5cb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py @@ -93,6 +93,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components From 47abcf6bf5716a32d79287421e97a8d5b6dc5c0c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Dec 2023 10:12:14 +0530 Subject: [PATCH 83/95] fix: lora_alpha --- examples/consistency_distillation/train_lcm_distill_lora_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 73e56293f372..cd36ad1b8918 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -762,6 +762,7 @@ def main(args): # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. 
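    # (Illustrative editor's note, not in the original script:) PEFT scales each LoRA update by
    # lora_alpha / r, so pairing lora_alpha with the chosen rank, as the added line below does,
    # keeps the effective scale at 1.0 instead of whatever the library default alpha would give.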
lora_config = LoraConfig( r=args.lora_rank, + lora_alpha=args.lora_rank, target_modules=[ "to_q", "to_k", From b7c0f95ff7b86258368b4a380fb83294da15e5b8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Dec 2023 10:13:12 +0530 Subject: [PATCH 84/95] make vae casting conditional/ --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index cd36ad1b8918..00648de1335c 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -755,7 +755,10 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. unet.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=torch.float32) + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From 7a1d6c9077c1f902a4defea221afbc13a67fe2e4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Dec 2023 10:18:20 +0530 Subject: [PATCH 85/95] param upcasting --- .../train_lcm_distill_lora_sdxl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 00648de1335c..4ac9094b177e 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -785,6 +785,13 @@ def main(args): ) unet.add_adapter(lora_config) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) @@ -855,11 +862,7 @@ def load_model_hook(models, input_dir): optimizer_class = torch.optim.AdamW # 12. 
Optimizer creation - params_to_optimize = [] - for param in unet.parameters(): - if param.requires_grad: - param.data = param.to(torch.float32) - params_to_optimize.append(param) + params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, From 87f87a70631d01e765d2e8cdbc1078a6e2908312 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Dec 2023 10:37:07 +0530 Subject: [PATCH 86/95] propagate comments from https://github.com/huggingface/diffusers/pull/6145 Co-authored-by: dg845 --- .../train_lcm_distill_lora_sdxl.py | 132 ++++++++++++------ 1 file changed, 88 insertions(+), 44 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 4ac9094b177e..abf8fc62adcb 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -199,19 +199,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -676,16 +700,17 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. 
solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-XL checkpoint. + # 2. Load tokenizers from SDXL checkpoint. tokenizer_one = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) @@ -693,7 +718,7 @@ def main(args): args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False ) - # 3. Load text encoders from SD-XL checkpoint. + # 3. Load text encoders from SDXL checkpoint. # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( args.pretrained_teacher_model, args.teacher_revision @@ -709,7 +734,7 @@ def main(args): args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision ) - # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + # 4. Load VAE from SDXL checkpoint (or more stable VAE) vae_path = ( args.pretrained_teacher_model if args.pretrained_vae_model_name_or_path is None @@ -726,7 +751,7 @@ def main(args): text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) - # 7. Create online (`unet`) student U-Net. + # 7. Create online student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -743,7 +768,7 @@ def main(args): f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) - # 9. Handle mixed precision and device placement + # 8. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 @@ -762,7 +787,7 @@ def main(args): text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) - # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. lora_config = LoraConfig( r=args.lora_rank, lora_alpha=args.lora_rank, @@ -1007,7 +1032,7 @@ def compute_time_ids(original_size, crops_coords_top_left): compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers) - # 14. LR Scheduler creation + # 15. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1027,7 +1052,7 @@ def compute_time_ids(original_size, crops_coords_top_left): num_training_steps=args.max_train_steps * accelerator.num_processes, ) - # 15. Prepare for training + # 16. Prepare for training # Prepare everything with our `accelerator`. unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler @@ -1046,7 +1071,7 @@ def compute_time_ids(original_size, crops_coords_top_left): tracker_config = dict(vars(args)) accelerator.init_trackers(args.tracker_project_name, config=tracker_config) - # 16. Train! + # 17. Train! 
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") @@ -1098,6 +1123,7 @@ def compute_time_ids(original_size, crops_coords_top_left): for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning pixel_values, text, orig_size, crop_coords = ( batch["pixel_values"], batch["captions"], @@ -1118,44 +1144,43 @@ def compute_time_ids(original_size, crops_coords_top_left): if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] bsz = latents.shape[0] - - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, encoder_hidden_states=prompt_embeds, added_cond_kwargs=encoded_text, ).sample - - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1165,20 +1190,28 @@ def compute_time_ids(original_size, crops_coords_top_left): ) model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding - # Notice that we're disabling the adapter layers within the `unet` and then it becomes a - # regular teacher. This way, we don't have to separately initialize a teacher UNet. + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. unet.disable_adapters() with torch.no_grad(): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = unet( noisy_model_input, start_timesteps, encoder_hidden_states=prompt_embeds, added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1187,11 +1220,10 @@ def compute_time_ids(original_size, crops_coords_top_left): sigma_schedule, ) - # Create uncond embeds for classifier free guidance + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_prompt_embeds = torch.zeros_like(prompt_embeds) uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) uncond_added_conditions = copy.deepcopy(encoded_text) - # Get teacher model prediction on noisy_latents and unconditional embedding uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = unet( noisy_model_input, @@ -1199,7 +1231,15 @@ def compute_time_ids(original_size, crops_coords_top_left): encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1208,16 +1248,20 @@ def compute_time_ids(original_size, crops_coords_top_left): sigma_schedule, ) - # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. 
Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) - x_prev = solver.ddim_step(pred_x0, pred_noise, index) - x_prev = x_prev.to(unet.dtype) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. + x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype) # re-enable unet adapters unet.enable_adapters() - # Get target LCM prediction on x_prev, w, c, t_n + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): target_noise_pred = unet( x_prev, @@ -1225,7 +1269,7 @@ def compute_time_ids(original_size, crops_coords_top_left): encoder_hidden_states=prompt_embeds, added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1235,7 +1279,7 @@ def compute_time_ids(original_size, crops_coords_top_left): ) target = c_skip * x_prev + c_out * pred_x_0 - # Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1243,7 +1287,7 @@ def compute_time_ids(original_size, crops_coords_top_left): torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # Backpropagate on the online student model (`unet`) + # 11. 
Backpropagate on the online student model (`unet`) (only LoRA) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) From 4c7e983bb5929320bab08d70333eeb93f047de40 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Dec 2023 11:39:28 +0100 Subject: [PATCH 87/95] [Peft] fix saving / loading when unet is not "unet" (#6046) * [Peft] fix saving / loading when unet is not "unet" * Update src/diffusers/loaders/lora.py Co-authored-by: Sayak Paul * undo stablediffusion-xl changes * use unet_name to get unet for lora helpers * use unet_name --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/ip_adapter.py | 6 ++-- src/diffusers/loaders/lora.py | 46 ++++++++++++++++++----------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 158bde436374..3df0492380e5 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -149,9 +149,11 @@ def load_ip_adapter( self.feature_extractor = CLIPImageProcessor() # load ip-adapter into unet - self.unet._load_ip_adapter_weights(state_dict) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet._load_ip_adapter_weights(state_dict) def set_ip_adapter_scale(self, scale): - for attn_processor in self.unet.attn_processors.values(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): attn_processor.scale = scale diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index fc50c52e412b..2ceff743daca 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -912,10 +912,10 @@ def pack_weights(layers, prefix): ) if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) + state_dict.update(pack_weights(unet_lora_layers, cls.unet_name)) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) if transformer_lora_layers: state_dict.update(pack_weights(transformer_lora_layers, "transformer")) @@ -975,6 +975,8 @@ def unload_lora_weights(self): >>> ... ``` """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if not USE_PEFT_BACKEND: if version.parse(__version__) > version.parse("0.23"): logger.warn( @@ -982,13 +984,13 @@ def unload_lora_weights(self): "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." ) - for _, module in self.unet.named_modules(): + for _, module in unet.named_modules(): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) else: - recurse_remove_peft_layers(self.unet) - if hasattr(self.unet, "peft_config"): - del self.unet.peft_config + recurse_remove_peft_layers(unet) + if hasattr(unet, "peft_config"): + del unet.peft_config # Safe to call the following regardless of LoRA. 
self._remove_text_encoder_monkey_patch() @@ -1027,7 +1029,8 @@ def fuse_lora( ) if fuse_unet: - self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) if USE_PEFT_BACKEND: from peft.tuners.tuners_utils import BaseTunerLayer @@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: if not USE_PEFT_BACKEND: - self.unet.unfuse_lora() + unet.unfuse_lora() else: from peft.tuners.tuners_utils import BaseTunerLayer - for module in self.unet.modules(): + for module in unet.modules(): if isinstance(module, BaseTunerLayer): module.unmerge() @@ -1202,8 +1206,9 @@ def set_adapters( adapter_names: Union[List[str], str], adapter_weights: Optional[List[float]] = None, ): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - self.unet.set_adapters(adapter_names, adapter_weights) + unet.set_adapters(adapter_names, adapter_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): @@ -1216,7 +1221,8 @@ def disable_lora(self): raise ValueError("PEFT backend is required for this method.") # Disable unet adapters - self.unet.disable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.disable_lora() # Disable text encoder adapters if hasattr(self, "text_encoder"): @@ -1229,7 +1235,8 @@ def enable_lora(self): raise ValueError("PEFT backend is required for this method.") # Enable unet adapters - self.unet.enable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.enable_lora() # Enable text encoder adapters if hasattr(self, "text_encoder"): @@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): adapter_names = [adapter_names] # Delete unet adapters - self.unet.delete_adapters(adapter_names) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.delete_adapters(adapter_names) for adapter_name in adapter_names: # Delete text encoder adapters @@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]: from peft.tuners.tuners_utils import BaseTunerLayer active_adapters = [] - - for module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for module in unet.modules(): if isinstance(module, BaseTunerLayer): active_adapters = module.active_adapters break @@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]: if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) - if hasattr(self, "unet") and hasattr(self.unet, "peft_config"): - set_adapters["unet"] = list(self.unet.peft_config.keys()) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"): + set_adapters[self.unet_name] = list(self.unet.peft_config.keys()) return set_adapters @@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, from peft.tuners.tuners_utils import BaseTunerLayer # Handle the UNET 
- for unet_module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for unet_module in unet.modules(): if isinstance(unet_module, BaseTunerLayer): for adapter_name in adapter_names: unet_module.lora_A[adapter_name].to(device) From 0bb9cf0216e501632677895de6574532092282b5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Dec 2023 11:40:04 +0100 Subject: [PATCH 88/95] [Wuerstchen] fix fp16 training and correct lora args (#6245) fix fp16 training Co-authored-by: Sayak Paul --- .../text_to_image/train_text_to_image_lora_prior.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index 1e67f05abe7a..f1f6b3215201 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -527,9 +527,17 @@ def deepspeed_zero_init_disabled_context_manager(): # lora attn processor prior_lora_config = LoraConfig( - r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"] + r=args.rank, + lora_alpha=args.rank, + target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], ) + # Add adapter and make sure the trainable params are in float32. prior.add_adapter(prior_lora_config) + if args.mixed_precision == "fp16": + for param in prior.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): From 11659a6f74b5187f601eeeeeb6f824dda73d0627 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 26 Dec 2023 19:13:49 +0530 Subject: [PATCH 89/95] [docs] fix: animatediff docs (#6339) fix: animatediff docs --- .../pipelines/animatediff/pipeline_animatediff.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 0dab722e51a8..b0fe790c2222 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -33,7 +33,14 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -47,7 +54,7 @@ >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif - >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter") + >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) >>> output = pipe(prompt="A corgi walking in the park") @@ -533,6 +540,7 @@ def prepare_latents( return latents @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] 
= None,

From f645b87eb3f243e825d70ab0a5d4391a20b78fa3 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 26 Dec 2023 20:26:55 +0530
Subject: [PATCH 90/95] add: note about the new script in readme_sdxl.

---
 .../consistency_distillation/README_sdxl.md   | 35 +++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/examples/consistency_distillation/README_sdxl.md b/examples/consistency_distillation/README_sdxl.md
index 16d32bcc571e..d3abaa4ce175 100644
--- a/examples/consistency_distillation/README_sdxl.md
+++ b/examples/consistency_distillation/README_sdxl.md
@@ -111,4 +111,37 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
   --report_to=wandb \
   --seed=453645634 \
   --push_to_hub \
-```
\ No newline at end of file
+```
+
+We also provide `train_lcm_distill_lora_sdxl.py`, another LCM LoRA SDXL training script that follows `peft` best practices and uses the `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, it doesn't load two UNets, which reduces the memory requirements quite a bit.
+
+Below is an example training command that trains an LCM LoRA on the [Pokémon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_lcm_distill_lora_sdxl.py \
+  --pretrained_teacher_model=${MODEL_NAME} \
+  --pretrained_vae_model_name_or_path=${VAE_PATH} \
+  --output_dir="pokemons-lora-lcm-sdxl" \
+  --mixed_precision="fp16" \
+  --dataset_name=$DATASET_NAME \
+  --resolution=1024 \
+  --train_batch_size=24 \
+  --gradient_accumulation_steps=1 \
+  --gradient_checkpointing \
+  --use_8bit_adam \
+  --lora_rank=64 \
+  --learning_rate=1e-4 \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=3000 \
+  --checkpointing_steps=500 \
+  --validation_steps=50 \
+  --seed="0" \
+  --push_to_hub
+```
+

From fd64acf9a9cf6367421bd4c561abc0e67d1ee39f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 26 Dec 2023 20:27:36 +0530
Subject: [PATCH 91/95] Revert "[Peft] fix saving / loading when unet is not "unet" (#6046)"

This reverts commit 4c7e983bb5929320bab08d70333eeb93f047de40.
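For context, the reverted commit taught the LoRA and IP-Adapter helper methods to resolve the denoiser dynamically through `self.unet_name` instead of assuming the pipeline exposes a plain `unet` attribute. A minimal sketch of that fallback, written as a hypothetical standalone helper purely for illustration:

```python
# Hypothetical helper illustrating the fallback removed by this revert: prefer the
# pipeline's `unet` attribute, otherwise look the denoiser up by its registered name.
def resolve_unet(pipeline):
    if hasattr(pipeline, "unet"):
        return pipeline.unet
    return getattr(pipeline, pipeline.unet_name)
```

After this revert, the methods in the diff below go back to calling `self.unet` directly.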
--- src/diffusers/loaders/ip_adapter.py | 6 ++-- src/diffusers/loaders/lora.py | 46 +++++++++++------------------ 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 3df0492380e5..158bde436374 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -149,11 +149,9 @@ def load_ip_adapter( self.feature_extractor = CLIPImageProcessor() # load ip-adapter into unet - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet._load_ip_adapter_weights(state_dict) + self.unet._load_ip_adapter_weights(state_dict) def set_ip_adapter_scale(self, scale): - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for attn_processor in unet.attn_processors.values(): + for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): attn_processor.scale = scale diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 2ceff743daca..fc50c52e412b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -912,10 +912,10 @@ def pack_weights(layers, prefix): ) if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, cls.unet_name)) + state_dict.update(pack_weights(unet_lora_layers, "unet")) if text_encoder_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) if transformer_lora_layers: state_dict.update(pack_weights(transformer_lora_layers, "transformer")) @@ -975,8 +975,6 @@ def unload_lora_weights(self): >>> ... ``` """ - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if not USE_PEFT_BACKEND: if version.parse(__version__) > version.parse("0.23"): logger.warn( @@ -984,13 +982,13 @@ def unload_lora_weights(self): "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." ) - for _, module in unet.named_modules(): + for _, module in self.unet.named_modules(): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) else: - recurse_remove_peft_layers(unet) - if hasattr(unet, "peft_config"): - del unet.peft_config + recurse_remove_peft_layers(self.unet) + if hasattr(self.unet, "peft_config"): + del self.unet.peft_config # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() @@ -1029,8 +1027,7 @@ def fuse_lora( ) if fuse_unet: - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) + self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing) if USE_PEFT_BACKEND: from peft.tuners.tuners_utils import BaseTunerLayer @@ -1083,14 +1080,13 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. 
""" - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: if not USE_PEFT_BACKEND: - unet.unfuse_lora() + self.unet.unfuse_lora() else: from peft.tuners.tuners_utils import BaseTunerLayer - for module in unet.modules(): + for module in self.unet.modules(): if isinstance(module, BaseTunerLayer): module.unmerge() @@ -1206,9 +1202,8 @@ def set_adapters( adapter_names: Union[List[str], str], adapter_weights: Optional[List[float]] = None, ): - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - unet.set_adapters(adapter_names, adapter_weights) + self.unet.set_adapters(adapter_names, adapter_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): @@ -1221,8 +1216,7 @@ def disable_lora(self): raise ValueError("PEFT backend is required for this method.") # Disable unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.disable_lora() + self.unet.disable_lora() # Disable text encoder adapters if hasattr(self, "text_encoder"): @@ -1235,8 +1229,7 @@ def enable_lora(self): raise ValueError("PEFT backend is required for this method.") # Enable unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.enable_lora() + self.unet.enable_lora() # Enable text encoder adapters if hasattr(self, "text_encoder"): @@ -1258,8 +1251,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): adapter_names = [adapter_names] # Delete unet adapters - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet.delete_adapters(adapter_names) + self.unet.delete_adapters(adapter_names) for adapter_name in adapter_names: # Delete text encoder adapters @@ -1292,8 +1284,8 @@ def get_active_adapters(self) -> List[str]: from peft.tuners.tuners_utils import BaseTunerLayer active_adapters = [] - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for module in unet.modules(): + + for module in self.unet.modules(): if isinstance(module, BaseTunerLayer): active_adapters = module.active_adapters break @@ -1317,9 +1309,8 @@ def get_list_adapters(self) -> Dict[str, List[str]]: if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"): - set_adapters[self.unet_name] = list(self.unet.peft_config.keys()) + if hasattr(self, "unet") and hasattr(self.unet, "peft_config"): + set_adapters["unet"] = list(self.unet.peft_config.keys()) return set_adapters @@ -1340,8 +1331,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, from peft.tuners.tuners_utils import BaseTunerLayer # Handle the UNET - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for unet_module in unet.modules(): + for unet_module in self.unet.modules(): if isinstance(unet_module, BaseTunerLayer): for adapter_name in adapter_names: unet_module.lora_A[adapter_name].to(device) From 121567b07c7842c368e0bf182764267dad1fda4b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Dec 2023 20:27:45 +0530 Subject: [PATCH 92/95] Revert "[Wuerstchen] fix fp16 training and correct lora args (#6245)" This reverts commit 0bb9cf0216e501632677895de6574532092282b5. 
--- .../text_to_image/train_text_to_image_lora_prior.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index f1f6b3215201..1e67f05abe7a 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -527,17 +527,9 @@ def deepspeed_zero_init_disabled_context_manager(): # lora attn processor prior_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.rank, - target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], + r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"] ) - # Add adapter and make sure the trainable params are in float32. prior.add_adapter(prior_lora_config) - if args.mixed_precision == "fp16": - for param in prior.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): From c24626ae3bb5ff2d885ca02cc991fe0b1692ad65 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Dec 2023 20:27:57 +0530 Subject: [PATCH 93/95] Revert "[docs] fix: animatediff docs (#6339)" This reverts commit 11659a6f74b5187f601eeeeeb6f824dda73d0627. --- .../pipelines/animatediff/pipeline_animatediff.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b0fe790c2222..0dab722e51a8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -33,14 +33,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import ( - USE_PEFT_BACKEND, - BaseOutput, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -54,7 +47,7 @@ >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif - >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") + >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter") >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) >>> output = pipe(prompt="A corgi walking in the park") @@ -540,7 +533,6 @@ def prepare_latents( return latents @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, From 4c689b29350c7ffcff85c2430c6fdf734010da34 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Dec 2023 20:29:38 +0530 Subject: [PATCH 94/95] remove tokenize_prompt(). 
---
 .../train_lcm_distill_lora_sdxl.py            | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index abf8fc62adcb..0b706b64dfba 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -242,18 +242,6 @@ def extract_into_tensor(a, t, x_shape):
     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
 
 
-def tokenize_prompt(tokenizer, prompt):
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    return text_input_ids
-
-
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):

From 1b49fb92dd675088fb94806ba8f7f1e363c275ec Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 26 Dec 2023 20:31:35 +0530
Subject: [PATCH 95/95] add assistive comments around enable_adapters() and disable_adapters().

---
 .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 0b706b64dfba..2733eb146cd3 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -1182,6 +1182,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
+
+                # With the adapters disabled, the `unet` is the regular teacher model.
                 unet.disable_adapters()
                 with torch.no_grad():
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
@@ -1245,7 +1247,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
                     x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
 
-                # re-enable unet adapters
+                # re-enable unet adapters to turn the `unet` into a student unet.
                 unet.enable_adapters()
 
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
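The comments added above highlight the central trick of `train_lcm_distill_lora_sdxl.py`: a single UNet carries a LoRA adapter and plays both roles, acting as the frozen teacher with the adapter disabled and as the trainable student with it enabled. A minimal sketch of the pattern follows; the model id matches the README example, while the rank and `target_modules` shown here are illustrative placeholders rather than the script's actual CLI-configured values:

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Load the SDXL teacher UNet and freeze its base weights.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)

# Attach a LoRA adapter; only the adapter weights will be trainable.
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(lora_config)

# Teacher pass: with the adapters disabled, the UNet behaves as the frozen base model.
unet.disable_adapters()
# ... run the teacher / DDIM-solver predictions under torch.no_grad() ...

# Student pass: re-enable the adapters so gradients flow only through the LoRA layers.
unet.enable_adapters()
# ... run the student prediction and compute the distillation loss ...
```

Because one UNet serves both passes, no separate frozen teacher copy has to be kept in memory, which is the saving over `train_lcm_distill_lora_sdxl_wds.py` noted in the README.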