From 13f121245143ddf07e60ef1cb7c08d1cf912cb38 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Fri, 14 Jul 2023 00:17:05 +1000 Subject: [PATCH 01/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 1051 +++++++++++++++++ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 963 +++++++++++++++ 2 files changed, 2014 insertions(+) create mode 100644 examples/instruct_pix2pix/train_instruct_pix2pix_xl.py create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py new file mode 100644 index 000000000000..7e708fe8e104 --- /dev/null +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -0,0 +1,1051 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to fine-tune Stable Diffusion for InstructPix2Pix.""" + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import PIL +import requests +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"), +} +WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--original_image_column", + type=str, + default="input_image", + help="The column of the dataset containing the original image on which edits where made.", + ) + parser.add_argument( + "--edited_image_column", + type=str, + default="edited_image", + help="The column of the dataset containing the edited image.", + ) + parser.add_argument( + "--edit_prompt_column", + type=str, + default="edit_prompt", + help="The column of the dataset containing the edit instruction.", + ) + parser.add_argument( + "--val_image_url", + type=str, + default=None, + help="URL to the original image that you would like to edit (used during inference for debugging purposes).", + ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="instruct-pix2pix-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=256, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--conditioning_dropout_prob", + type=float, + default=None, + help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. 
Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def convert_to_np(image, resolution): + if isinstance(image, str): + image = PIL.Image.open(image) + image = image.convert("RGB").resize((resolution, resolution)) + return np.array(image).transpose(2, 0, 1) + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. 
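+    # SDXL conditions the UNet on two frozen text encoders (a CLIP text model and a CLIP text model
+    # with projection), so both are loaded here alongside the tokenizer, VAE and UNet.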
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+    )
+
+    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+    # initialized to zero.
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
+
+    # Freeze the VAE and both text encoders; only the UNet is trained.
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    text_encoder_2.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. 
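+    # Resolve which dataset columns hold the original image, the edit instruction and the edited image,
+    # and validate that they actually exist in the loaded dataset.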
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.original_image_column is None: + original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + original_image_column = args.original_image_column + if original_image_column not in column_names: + raise ValueError( + f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edit_prompt_column is None: + edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + edit_prompt_column = args.edit_prompt_column + if edit_prompt_column not in column_names: + raise ValueError( + f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edited_image_column is None: + edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] + else: + edited_image_column = args.edited_image_column + if edited_image_column not in column_names: + raise ValueError( + f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(captions): + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + ) + + def preprocess_images(examples): + original_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[original_image_column]] + ) + edited_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[edited_image_column]] + ) + # We need to ensure that the original and the edited images undergo the same + # augmentation transforms. + images = np.concatenate([original_images, edited_images]) + images = torch.tensor(images) + images = 2 * (images / 255) - 1 + return train_transforms(images) + + def preprocess_train(examples): + # Preprocess images. + preprocessed_images = preprocess_images(examples) + # Since the original and edited images were concatenated before + # applying the transformations, we need to separate them and reshape + # them accordingly. + original_images, edited_images = preprocessed_images.chunk(2) + original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) + + # Collate the preprocessed images into the `examples`. + examples["original_pixel_values"] = original_images + examples["edited_pixel_values"] = edited_images + + # Preprocess the captions. 
+ captions = list(examples[edit_prompt_column]) + examples["input_ids"] = tokenize_captions(captions) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples]) + original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() + edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) + edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return { + "original_pixel_values": original_pixel_values, + "edited_pixel_values": edited_pixel_values, + "input_ids": input_ids, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + text_encoders = [text_encoder, text_encoder_2] + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("instruct-pix2pix", config=vars(args)) + + # Train! 
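+    # The effective batch size per optimizer step is the per-device batch size multiplied by the
+    # number of processes and the number of gradient-accumulation steps.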
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # We want to learn the denoising process w.r.t the edited images which + # are conditioned on the original image (which was edited) and the edit instruction. + # So, first, convert images to latent space. + latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning. 
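+                # SDXL consumes the penultimate hidden states of the text encoder(s) as cross-attention
+                # conditioning and keeps the pooled output around for the added text embedding.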
+ prompt_embeds_list = [] + for text_encoder in text_encoders[1:]: + prompt_embeds = text_encoder(batch["input_ids"], output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + encoder_hidden_states = torch.concat(prompt_embeds_list, dim=-1) + + # Get the additional image embedding for conditioning. + # Instead of getting a diagonal Gaussian here, we simply take the mode. + original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode() + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = original_image_embeds.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + original_image_embeds = image_mask * original_image_embeds + + # Concatenate the `original_image_embeds` with the `noisy_latents`. 
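+                # InstructPix2Pix conditions the UNet on the original image by stacking its latents with
+                # the noisy latents along the channel dimension, matching the 8-channel conv_in created above.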
+                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                ### Begin SDXL
+                bs_embed = pooled_prompt_embeds.shape[0]
+                pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+                add_text_embeds = pooled_prompt_embeds
+
+                crops_coords_top_left = (0, 0)
+                target_size = (args.resolution, args.resolution)
+                original_size = original_image_embeds.shape[-2:]
+                add_time_ids = list(original_size + crops_coords_top_left + target_size)
+                add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+                add_time_ids = torch.tensor([add_time_ids], dtype=encoder_hidden_states.dtype)
+                add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=encoder_hidden_states.dtype)
+                add_time_ids = add_time_ids.to(encoder_hidden_states.device).repeat(args.train_batch_size, 1)
+                ### End SDXL
+
+                # Predict the noise residual and compute loss
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                model_pred = unet(
+                    concatenated_noisy_latents,
+                    timesteps,
+                    encoder_hidden_states,
+                    cross_attention_kwargs=None,
+                    added_cond_kwargs=added_cond_kwargs,
+                ).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                # Gather the losses across all processes for logging (if we use distributed training).
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet.parameters())
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        
accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if ( + (args.val_image_url is not None) + and (args.validation_prompt is not None) + and (epoch % args.validation_epochs == 0) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + # The models need unwrapping because for compatibility in distributed training mode. + pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + original_image = download_image(args.val_image_url) + edited_images = [] + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"validation": wandb_table}) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + del pipeline + torch.cuda.empty_cache() + + # Create the pipeline using the trained modules and save it. 
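+    # If EMA was enabled, the EMA weights are copied into the UNet before the final pipeline is exported.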
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), + unet=unet, + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + if args.validation_prompt is not None: + edited_images = [] + pipeline = pipeline.to(accelerator.device) + with torch.autocast(str(accelerator.device).replace(":0", "")): + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"test": wandb_table}) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py new file mode 100644 index 000000000000..7bfcea0860c0 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -0,0 +1,963 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . 
import StableDiffusionXLPipelineOutput +from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
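+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder used by SDXL in addition to `text_encoder`.
+        tokenizer_2 (`CLIPTokenizer`):
+            Second tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).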
+ """ + _optional_components = ["tokenizer", "text_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder_2, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.text_encoder is not None: + cpu_offload(self.text_encoder, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are 
only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the inital image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. 
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + aesthetic_score (`float`, *optional*, defaults to 6.0): + TODO + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + TDOO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. + if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. 
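In sigma space the two parameterizations are related by x0 = x_t - sigma * eps, which is what the conversion above applies before the guidance combination and what has to be undone again before the scheduler step. A minimal sketch of that round trip, with placeholder tensors rather than pipeline attributes:

import torch

def eps_to_x0(eps, sample, sigma):
    # epsilon prediction -> denoised estimate
    return sample - sigma * eps

def x0_to_eps(x0, sample, sigma):
    # denoised estimate -> epsilon prediction (inverse of the above)
    return (x0 - sample) / (-sigma)

sample = torch.randn(1, 4, 64, 64)
eps = torch.randn_like(sample)
sigma = torch.tensor(1.7)
assert torch.allclose(x0_to_eps(eps_to_x0(eps, sample, sigma), sample, sigma), eps, atol=1e-6)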
But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. + if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) From 67a401c9b1a60bc4600487bd08c8e7bc1d6ea6f1 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Fri, 14 Jul 2023 11:06:26 +1000 Subject: [PATCH 02/67] Support instruction pix2pix sdxl --- ...line_stable_diffusion_xl_instruct_pix2pix.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7bfcea0860c0..33556d060fa2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -855,6 +855,10 @@ def __call__( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + prompt_embeds = 
prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -874,14 +878,7 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - noise_pred = self.unet( - scaled_latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -960,4 +957,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From 4514be5c7654c4aa0558329de78e6f9cad2adde5 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sat, 15 Jul 2023 22:37:05 +1000 Subject: [PATCH 03/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 2 -- ...line_stable_diffusion_xl_instruct_pix2pix.py | 17 +++++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 7e708fe8e104..c3285074b0bb 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -448,7 +448,6 @@ def main(): unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) - import pdb; pdb.set_trace() # InstructPix2Pix uses an additional image for conditioning. To accommodate that, # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. 
This UNet is @@ -878,7 +877,6 @@ def collate_fn(examples): # Predict the noise residual and compute loss added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - import pdb; pdb.set_trace() model_pred = unet(concatenated_noisy_latents[:, :4, :, :], timesteps, encoder_hidden_states, cross_attention_kwargs=None, added_cond_kwargs=added_cond_kwargs, return_dict=False).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 33556d060fa2..b11d7cf27ba9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -528,7 +528,8 @@ def prepare_image_latents( raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - + + tmp_img = torch.load('/home/users/u5689359/gitRepo_mill/Lycium/tmp_image.pt').to('cuda') image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -536,6 +537,11 @@ def prepare_image_latents( if image.shape[1] == 4: image_latents = image else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -859,11 +865,12 @@ def __call__( add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) + prompt_embeds = prompt_embeds.to(device).to(torch.float32) + add_text_embeds = add_text_embeds.to(device).to(torch.float32) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 11. 
Denoising loop + self.unet = self.unet.to(torch.float32) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -878,6 +885,8 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: @@ -957,4 +966,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From 187fc3610de6c67666d3912836435ca1cfc6f200 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 15:17:36 +1000 Subject: [PATCH 04/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 83 +++++++++++++------ 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index c3285074b0bb..3ae320692884 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -52,6 +52,8 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline +from PIL import Image + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.18.0.dev0") @@ -435,10 +437,13 @@ def main(): # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained( + tokenizer_1 = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( + tokenizer_2 = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision + ) + text_encoder_1 = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( @@ -455,8 +460,7 @@ def main(): # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") - # in_channels = 8 - in_channels = 4 + in_channels = 8 out_channels = unet.conv_in.out_channels unet.register_to_config(in_channels=in_channels) @@ -470,7 +474,8 @@ def main(): # Freeze vae and text_encoder vae.requires_grad_(False) - text_encoder.requires_grad_(False) + text_encoder_1.requires_grad_(False) + text_encoder_2.requires_grad_(False) # Create EMA for the unet. if args.use_ema: @@ -614,9 +619,9 @@ def load_model_hook(models, input_dir): # Preprocessing the datasets. 
# We need to tokenize input captions and transform the images. - def tokenize_captions(captions): - inputs = tokenizer( - captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + def tokenize_captions(captions, a_tokenizer): + inputs = a_tokenizer( + captions, max_length=a_tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ) return inputs.input_ids @@ -658,7 +663,8 @@ def preprocess_train(examples): # Preprocess the captions. captions = list(examples[edit_prompt_column]) - examples["input_ids"] = tokenize_captions(captions) + examples["input_ids"] = tokenize_captions(captions, tokenizer_1) + examples["input_ids_2"] = tokenize_captions(captions, tokenizer_2) return examples with accelerator.main_process_first(): @@ -673,10 +679,12 @@ def collate_fn(examples): edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = torch.stack([example["input_ids"] for example in examples]) + input_ids_2 = torch.stack([example["input_ids_2"] for example in examples]) return { "original_pixel_values": original_pixel_values, "edited_pixel_values": edited_pixel_values, "input_ids": input_ids, + "input_ids_2": input_ids_2, } # DataLoaders creation: @@ -719,9 +727,9 @@ def collate_fn(examples): weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu and cast to weight_dtype - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) - text_encoders = [text_encoder, text_encoder_2] + text_encoders = [text_encoder_1, text_encoder_2] vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -792,6 +800,7 @@ def collate_fn(examples): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. # So, first, convert images to latent space. + # tmp_pixel_value = torch.load('xl_image.pt').to('cuda') latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor @@ -806,21 +815,24 @@ def collate_fn(examples): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning. 
+ ### Begin encoder prompt prompt_embeds_list = [] - for text_encoder in text_encoders[1:]: - prompt_embeds = text_encoder(batch["input_ids"], output_hidden_states=True) + for input_ids, text_encoder in zip((batch["input_ids"], batch["input_ids_2"]), text_encoders): + prompt_embeds = text_encoder(input_ids, output_hidden_states=True) + # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 1, 1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) - + encoder_hidden_states = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed, -1) + ### End encoder prompt # Get the additional image embedding for conditioning. # Instead of getting a diagonal Gaussian here, we simply take the mode. @@ -834,7 +846,17 @@ def collate_fn(examples): prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) # Final text conditioning. - null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] + ### Begin: Get null conditioning + null_conditioning_list = [] + for a_tokenizer, a_text_encoder in zip((tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2)): + null_conditioning_list.append( + a_text_encoder( + tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), + output_hidden_states=True + ).hidden_states[-2] + ) + ### End: Get null conditioning + null_conditioning = torch.concat(null_conditioning_list, dim=-1) encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. @@ -859,14 +881,10 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") ### Begin SDXL - bs_embed = pooled_prompt_embeds.shape[0] - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( - bs_embed * 1, -1 - ) add_text_embeds = pooled_prompt_embeds crops_coords_top_left = (0, 0) - target_size = (args.resolution, args.resolution) + target_size = (512, 512) original_size = original_image_embeds.shape[-2:] add_time_ids = list(original_size + crops_coords_top_left + target_size) add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -876,8 +894,12 @@ def collate_fn(examples): ### End SDXL # Predict the noise residual and compute loss + # tmp_prompt_embeds = torch.load('xl_prompt_embeds.pt').to('cuda') + # tmp_concatenated_noisy_latents = torch.load('xl_latent_model_input.pt').to('cuda') added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - model_pred = unet(concatenated_noisy_latents[:, :4, :, :], timesteps, encoder_hidden_states, cross_attention_kwargs=None, added_cond_kwargs=added_cond_kwargs, return_dict=False).sample + + model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + # model_pred = unet(concatenated_noisy_latents, timesteps, tmp_prompt_embeds, added_cond_kwargs=added_cond_kwargs).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
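The SDXL-specific conditioning assembled in this hunk reduces to a six-value time-ids vector (original size, crop top-left corner, target size) plus the pooled text embedding, both handed to the UNet through `added_cond_kwargs`. A minimal sketch of how that vector is built and broadcast over the batch; the sizes are illustrative and the tensors are placeholders:

import torch

batch_size = 4
original_size = (64, 64)           # here the spatial size of the conditioning image latents
crops_coords_top_left = (0, 0)
target_size = (512, 512)

add_time_ids = torch.tensor(
    [list(original_size + crops_coords_top_left + target_size)], dtype=torch.float32
).repeat(batch_size, 1)            # shape (batch_size, 6)

pooled_prompt_embeds = torch.randn(batch_size, 1280)  # pooled output of text_encoder_2
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
# the UNet consumes added_cond_kwargs alongside encoder_hidden_states and the noisy latents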
@@ -952,7 +974,10 @@ def collate_fn(examples): pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), vae=accelerator.unwrap_model(vae), revision=args.revision, torch_dtype=weight_dtype, @@ -961,7 +986,8 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - original_image = download_image(args.val_image_url) + # original_image = download_image(args.val_image_url) + original_image = Image.open(args.val_image_url).convert("RGB") edited_images = [] with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" @@ -1002,7 +1028,10 @@ def collate_fn(examples): pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), vae=accelerator.unwrap_model(vae), unet=unet, revision=args.revision, @@ -1046,4 +1075,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() + main() \ No newline at end of file From ecdd293bef04b2a1fd466e023a2b96ebfe1d8558 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 17:14:11 +1000 Subject: [PATCH 05/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 3ae320692884..572a4dad2ca3 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -144,7 +144,7 @@ def parse_args(): parser.add_argument( "--validation_epochs", type=int, - default=1, + default=100, help=( "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." @@ -952,6 +952,61 @@ def collate_fn(examples): logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) + ### BEGIN: Perform validation every `validation_epochs` steps + if global_step % args.validation_epochs == 0 or global_step == 1: + if ( + (args.val_image_url is not None) + and (args.validation_prompt is not None) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + + # create pipeline + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + # The models need unwrapping because for compatibility in distributed training mode. 
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), + vae=accelerator.unwrap_model(vae), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + # Save validation images + val_save_dir = os.path.join(args.output_dir, "validation_images") + if not os.path.exists(val_save_dir): + os.makedirs(val_save_dir) + + # original_image = download_image(args.val_image_url) + original_image = Image.open(args.val_image_url).convert("RGB") + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): + for val_img_idx in range(args.num_validation_images): + a_val_img = pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png")) + ### END: Perform validation every `validation_epochs` steps + if global_step >= args.max_train_steps: break @@ -986,6 +1041,11 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference + # Save validation images + val_save_dir = os.path.join(args.output_dir, "validation_images") + if not os.path.exists(val_save_dir): + os.makedirs(val_save_dir) + # original_image = download_image(args.val_image_url) original_image = Image.open(args.val_image_url).convert("RGB") edited_images = [] From 94df45cdd32c6999afec00984b6b0c194c9c2ce2 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 19:37:12 +1000 Subject: [PATCH 06/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 50 +++++++++++++++---- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 44 ++++++++-------- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 572a4dad2ca3..5c889824dac2 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -41,7 +41,8 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel @@ -66,6 +67,26 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers 
import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") parser.add_argument( @@ -436,22 +457,31 @@ def main(): ).repo_id # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer_1 = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + tokenizer_1 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + text_encoder_cls_1 = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision + text_encoder_cls_2 = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) - text_encoder_1 = CLIPTextModel.from_pretrained( + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_1 = text_encoder_cls_1.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + text_encoder_2 = text_encoder_cls_2.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # InstructPix2Pix uses an additional image for conditioning. To accommodate that, @@ -459,7 +489,7 @@ def main(): # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. 
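The in-channel expansion described above is usually done by swapping `unet.conv_in` for a wider convolution whose first four input channels copy the pretrained weights and whose extra channels start at zero. The following is a sketch following the original Stable Diffusion InstructPix2Pix recipe rather than a verbatim excerpt of this script; the checkpoint id is illustrative:

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")

in_channels = 8
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)

with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    new_conv_in.weight.zero_()
    # keep the pretrained weights for the four noisy-latent channels; the four image channels stay zero
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    unet.conv_in = new_conv_in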
- logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + logger.info("Initializing the XL InstructPix2Pix UNet from the pretrained UNet.") in_channels = 8 out_channels = unet.conv_in.out_channels unet.register_to_config(in_channels=in_channels) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b11d7cf27ba9..aef1d22021f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -529,7 +529,6 @@ def prepare_image_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - tmp_img = torch.load('/home/users/u5689359/gitRepo_mill/Lycium/tmp_image.pt').to('cuda') image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -794,22 +793,26 @@ def __call__( ) # 4. Preprocess image - image = self.image_processor.preprocess(image) + image = self.image_processor.preprocess(image).to(device) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - do_classifier_free_guidance, - generator, - ) + # image_latents = self.prepare_image_latents( + # image, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # do_classifier_free_guidance, + # generator, + # ) + image_latents = self.vae.encode(image).latent_dist.mode() + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -861,9 +864,9 @@ def __call__( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) - add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) - prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) prompt_embeds = prompt_embeds.to(device).to(torch.float32) add_text_embeds = add_text_embeds.to(device).to(torch.float32) @@ -877,7 +880,8 @@ def __call__( # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
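For context, the original InstructPix2Pix guidance combines three predictions (text plus image conditioning, image-only, fully unconditional), while the change in this hunk falls back to the standard two-term classifier-free guidance. The two combinations can be sketched side by side with placeholder tensors:

import torch

guidance_scale, image_guidance_scale = 7.0, 1.5

# three-way InstructPix2Pix guidance: text and image conditioning are weighted separately
noise_pred_text, noise_pred_image, noise_pred_uncond = torch.randn(3, 4, 64, 64).chunk(3)
guided_three_way = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)

# two-way classifier-free guidance used after this change
noise_pred_uncond, noise_pred_text = torch.randn(2, 4, 64, 64).chunk(2)
guided_two_way = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)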
- latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -885,7 +889,7 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] @@ -900,12 +904,8 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_image) - + image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf From fb3bf004edb5a5d8b687c5145e0e186ba1217ff2 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:24:10 +1000 Subject: [PATCH 07/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 871 ++++++++++++------ 1 file changed, 578 insertions(+), 293 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index aef1d22021f1..2d3ad8f8eb30 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -457,167 +457,154 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) +# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) +import inspect +import warnings +from typing import Callable, List, Optional, Union - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - return timesteps, num_inference_steps - t_start +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - batch_size = batch_size * num_images_per_prompt +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. - if image.shape[1] == 4: - image_latents = image - else: - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - image_latents = torch.cat([image_latents], dim=0) + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] - return image_latents - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + super().__init__() - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." 
+ if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: + + if safety_checker is not None and feature_extractor is None: raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -639,19 +626,10 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, ): r""" Function invoked when calling the pipeline for generation. @@ -660,9 +638,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): - The image(s) to modify with the pipeline. - num_inference_steps (`int`, *optional*, defaults to 50): + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also + accpet image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -670,7 +649,7 @@ def __call__( `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + usually at the expense of lower image quality. This pipeline requires a value of at least `1`. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to @@ -678,14 +657,14 @@ def __call__( image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): @@ -699,18 +678,11 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -718,42 +690,50 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - guidance_rescale (`float`, *optional*, defaults to 0.7): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - TODO - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - TODO - target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - TODO - aesthetic_score (`float`, *optional*, defaults to 6.0): - TODO - negative_aesthetic_score (`float`, *optional*, defaults to 2.5): - TDOO Examples: + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "make the mountains snowy" + >>> image = pipe(prompt=prompt, image=image).images[0] + ``` + Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. """ - # 1. Check inputs. Raise error if not correct + # 0. Check inputs self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") - # 2. Define call parameters + # 1. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -762,7 +742,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -770,16 +749,8 @@ def __call__( # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, @@ -787,32 +758,25 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, ) - # 4. Preprocess image - image = self.image_processor.preprocess(image).to(device) + # 3. Preprocess image + image = self.image_processor.preprocess(image) - # 5. Prepare timesteps + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 6. Prepare Image latents - # image_latents = self.prepare_image_latents( - # image, - # batch_size, - # num_images_per_prompt, - # prompt_embeds.dtype, - # device, - # do_classifier_free_guidance, - # generator, - # ) - image_latents = self.vae.encode(image).latent_dist.mode() - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -831,67 +795,37 @@ def __call__( latents, ) - # 8. Check that shapes of latents and image match the UNet channels + # 7. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents + num_channels_image}. Please verify the config of" + f" = {num_channels_latents+num_channels_image}. 
Please verify the config of" " `pipeline.unet` or your `image` input." ) - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 10. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids, add_neg_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - dtype=prompt_embeds.dtype, - ) - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - - # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) - # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) - # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) - - prompt_embeds = prompt_embeds.to(device).to(torch.float32) - add_text_embeds = add_text_embeds.to(device).to(torch.float32) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 11. Denoising loop - self.unet = self.unet.to(torch.float32) + # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
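                    # [Editor's note — not part of this patch] A minimal, self-contained sketch of the
                    # three-way classifier-free guidance that the comment above refers to. In this pipeline
                    # the batch is tripled and ordered (text-conditioned, image-conditioned, unconditional):
                    # prompt embeddings are stacked as [prompt, negative, negative] and image latents as
                    # [image, image, zeros], so the three noise predictions can be recombined with the two
                    # guidance scales as in the hunk below. Shapes and scale values here are illustrative
                    # assumptions only.
                    #
                    # import torch
                    #
                    # def combine_pix2pix_guidance(noise_pred, guidance_scale, image_guidance_scale):
                    #     # noise_pred has batch size 3 * N, ordered (text, image, uncond)
                    #     noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                    #     return (
                    #         noise_pred_uncond
                    #         + guidance_scale * (noise_pred_text - noise_pred_image)
                    #         + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    #     )
                    #
                    # # e.g. guidance_scale=7.5 pulls toward the edit instruction, image_guidance_scale=1.5
                    # # keeps the sample close to the input image:
                    # combined = combine_pix2pix_guidance(torch.randn(3, 4, 64, 64), 7.5, 1.5)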
- # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} - # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] + noise_pred = self.unet( + scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -904,12 +838,12 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -929,41 +863,392 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - return StableDiffusionXLPipelineOutput(images=image) + has_nsfw_concept = None - image = self.watermark.apply_watermark(image) - image = self.image_processor.postprocess(image, output_type=output_type) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = 
self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image,) + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get 
unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept - return StableDiffusionXLPipelineOutput(images=image) + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents From a31bdcf70cadd9b884294524de3172e358291ac4 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:28:32 +1000 Subject: [PATCH 08/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 871 ++++++------------ 1 file changed, 293 insertions(+), 578 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 2d3ad8f8eb30..e7a36349cbeb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -457,154 +457,167 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def check_inputs( -# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) -import inspect -import warnings -from typing import Callable, List, Optional, Union + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") -import numpy as np -import PIL -import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, -) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + return timesteps, num_inference_steps - t_start -def preprocess(image): - warnings.warn( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead", - FutureWarning, - ) - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): - r""" - Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." + ) - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) - In addition the pipeline inherits the following loading methods: - - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) - as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + batch_size = batch_size * num_images_per_prompt - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." ) - - if safety_checker is not None and feature_extractor is None: + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
) + else: + image_latents = torch.cat([image_latents], dim=0) - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -626,10 +639,19 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, ): r""" Function invoked when calling the pipeline for generation. @@ -638,10 +660,9 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also - accpet image latents as `image`, if passing latents directly, it will not be encoded again. - num_inference_steps (`int`, *optional*, defaults to 100): + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -649,7 +670,7 @@ def __call__( `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. This pipeline requires a value of at least `1`. + usually at the expense of lower image quality. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to @@ -657,14 +678,14 @@ def __call__( image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
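For reference, the dimension check in `_get_add_time_ids` above can be traced with concrete numbers. A minimal sketch, assuming the stock SDXL values `unet.config.addition_time_embed_dim = 256` and `text_encoder_2.config.projection_dim = 1280` (assumed values, not read from this diff):

```py
# Arithmetic behind the `expected_add_embed_dim` check in `_get_add_time_ids`,
# using assumed (typical SDXL) config values rather than values from this patch.
addition_time_embed_dim = 256  # unet.config.addition_time_embed_dim
projection_dim = 1280          # text_encoder_2.config.projection_dim

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

# Without an aesthetics score: six scalars (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
print(len(add_time_ids), passed_add_embed_dim)  # 6 2816

# With `requires_aesthetics_score=True`: five scalars (orig_h, orig_w, crop_top, crop_left, score),
# so the passed dimension drops by exactly one `addition_time_embed_dim`.
add_time_ids = list(original_size + crops_coords_top_left + (6.0,))
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
print(len(add_time_ids), passed_add_embed_dim)  # 5 2560
```

The two `ValueError` branches fire precisely when the mismatch equals one `addition_time_embed_dim`, i.e. when `requires_aesthetics_score` is set the wrong way round for the loaded UNet.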
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
- generator (`torch.Generator`, *optional*):
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
@@ -678,11 +699,18 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
@@ -690,50 +718,42 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ TODO
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ TODO
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ TODO
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ TODO
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ TODO
Examples:
- ```py
- >>> import PIL
- >>> import requests
- >>> import torch
- >>> from io import BytesIO
-
- >>> from diffusers import StableDiffusionInstructPix2PixPipeline
-
-
- >>> def download_image(url):
- ... response = requests.get(url)
- ...
return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - - >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" - - >>> image = download_image(img_url).resize((512, 512)) - - >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( - ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "make the mountains snowy" - >>> image = pipe(prompt=prompt, image=image).images[0] - ``` - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Check inputs + # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") - # 1. Define call parameters + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -742,6 +762,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -749,8 +770,16 @@ def __call__( # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - # 2. Encode input prompt - prompt_embeds = self._encode_prompt( + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -758,25 +787,32 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) - # 3. Preprocess image - image = self.image_processor.preprocess(image) + # 4. Preprocess image + image = self.image_processor.preprocess(image).to(device) - # 4. set timesteps + # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 5. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - do_classifier_free_guidance, - generator, - ) + # 6. 
Prepare Image latents + # image_latents = self.prepare_image_latents( + # image, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # do_classifier_free_guidance, + # generator, + # ) + image_latents = self.vae.encode(image).latent_dist.mode() + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -795,37 +831,67 @@ def __call__( latents, ) - # 7. Check that shapes of latents and image match the UNet channels + # 8. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 9. Denoising loop + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + + prompt_embeds = prompt_embeds.to(device).to(torch.float32) + add_text_embeds = add_text_embeds.to(device).to(torch.float32) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + self.unet = self.unet.to(torch.float32) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
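The comment above describes the original InstructPix2Pix scheme, in which classifier-free guidance combines three noise predictions: fully conditioned (text + image), image-only, and fully unconditional. A rough reference sketch of that combination, with illustrative shapes and the usual default scales (nothing below is taken from this diff):

```py
import torch

def combine_pix2pix_guidance(noise_pred, guidance_scale=7.5, image_guidance_scale=1.5):
    # Batch layout [text+image, image-only, unconditional], as produced when the
    # latents are expanded 3x (mirroring the SD 1.5 InstructPix2Pix pipeline).
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )

combined = combine_pix2pix_guidance(torch.randn(3, 4, 64, 64))  # toy shapes only
print(combined.shape)  # torch.Size([1, 4, 64, 64])
```

The hunk below temporarily switches to a plain two-way split, so the separate `image_guidance_scale` term is not applied in this intermediate state of the patch.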
- latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - noise_pred = self.unet( - scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False - )[0] + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] + noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -838,12 +904,12 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_image) - + image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -863,392 +929,41 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - has_nsfw_concept = None + return StableDiffusionXLPipelineOutput(images=image) - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_ prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept + return (image,) - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. 
Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - image_latents = image - else: - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) - - return image_latents + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From 1354a957f22a96d6087d0c125892828f24d12a23 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:47:06 +1000 Subject: [PATCH 09/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 54 +++++++------ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 75 ++++--------------- 2 files changed, 45 insertions(+), 84 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 5c889824dac2..657ca530aa9b 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -41,6 +41,7 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm + # from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import AutoTokenizer, PretrainedConfig @@ -51,7 +52,9 @@ from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import ( + StableDiffusionXLInstructPix2PixPipeline, +) from PIL import Image @@ -463,9 +466,7 @@ def main(): tokenizer_2 = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False ) - text_encoder_cls_1 = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) + text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) text_encoder_cls_2 = import_model_class_from_model_name_or_path( 
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) @@ -651,7 +652,11 @@ def load_model_hook(models, input_dir): # We need to tokenize input captions and transform the images. def tokenize_captions(captions, a_tokenizer): inputs = a_tokenizer( - captions, max_length=a_tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + captions, + max_length=a_tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", ) return inputs.input_ids @@ -714,7 +719,7 @@ def collate_fn(examples): "original_pixel_values": original_pixel_values, "edited_pixel_values": edited_pixel_values, "input_ids": input_ids, - "input_ids_2": input_ids_2, + "input_ids_2": input_ids_2, } # DataLoaders creation: @@ -878,11 +883,13 @@ def collate_fn(examples): # Final text conditioning. ### Begin: Get null conditioning null_conditioning_list = [] - for a_tokenizer, a_text_encoder in zip((tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2)): + for a_tokenizer, a_text_encoder in zip( + (tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2) + ): null_conditioning_list.append( a_text_encoder( - tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), - output_hidden_states=True + tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), + output_hidden_states=True, ).hidden_states[-2] ) ### End: Get null conditioning @@ -912,7 +919,7 @@ def collate_fn(examples): ### Begin SDXL add_text_embeds = pooled_prompt_embeds - + crops_coords_top_left = (0, 0) target_size = (512, 512) original_size = original_image_embeds.shape[-2:] @@ -922,13 +929,15 @@ def collate_fn(examples): add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=encoder_hidden_states.dtype) add_time_ids = add_time_ids.to(encoder_hidden_states.device).repeat(args.train_batch_size, 1) ### End SDXL - + # Predict the noise residual and compute loss # tmp_prompt_embeds = torch.load('xl_prompt_embeds.pt').to('cuda') # tmp_concatenated_noisy_latents = torch.load('xl_latent_model_input.pt').to('cuda') added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + + model_pred = unet( + concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ).sample # model_pred = unet(concatenated_noisy_latents, timesteps, tmp_prompt_embeds, added_cond_kwargs=added_cond_kwargs).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -983,11 +992,8 @@ def collate_fn(examples): progress_bar.set_postfix(**logs) ### BEGIN: Perform validation every `validation_epochs` steps - if global_step % args.validation_epochs == 0 or global_step == 1: - if ( - (args.val_image_url is not None) - and (args.validation_prompt is not None) - ): + if global_step % args.validation_epochs == 0 or global_step == 1: + if (args.val_image_url is not None) and (args.validation_prompt is not None): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." @@ -998,7 +1004,7 @@ def collate_fn(examples): # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) - + # The models need unwrapping because for compatibility in distributed training mode. 
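The `ema_unet.store(...)` / `ema_unet.copy_to(...)` calls just above, paired with the `ema_unet.restore(...)` call after validation, temporarily swap the EMA weights into the UNet so inference uses the smoothed parameters without losing the current training state. A minimal sketch of that pattern with a toy module standing in for the UNet (only the standard diffusers `EMAModel` methods are used):

```py
import torch
from diffusers.training_utils import EMAModel

# Toy stand-in for the UNet; EMAModel only needs an iterable of parameters.
model = torch.nn.Linear(4, 4)
ema = EMAModel(model.parameters(), decay=0.9999)

# During training, the EMA shadow weights would be updated after each optimizer step:
# ema.step(model.parameters())

# Around validation:
ema.store(model.parameters())    # 1. stash the current training weights
ema.copy_to(model.parameters())  # 2. run validation with the EMA weights
# ... build the validation pipeline and generate images here ...
ema.restore(model.parameters())  # 3. put the training weights back before resuming
```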
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1015,9 +1021,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - # Save validation images + # Save validation images val_save_dir = os.path.join(args.output_dir, "validation_images") - if not os.path.exists(val_save_dir): + if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) # original_image = download_image(args.val_image_url) @@ -1071,9 +1077,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - # Save validation images + # Save validation images val_save_dir = os.path.join(args.output_dir, "validation_images") - if not os.path.exists(val_save_dir): + if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) # original_image = download_image(args.val_image_url) @@ -1165,4 +1171,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index e7a36349cbeb..d6d688089d5f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,17 +1,3 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -356,7 +342,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -455,44 +441,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix - def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -519,7 +467,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def prepare_image_latents( self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None @@ -528,7 +476,7 @@ def prepare_image_latents( raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - + image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -540,7 +488,7 @@ def prepare_image_latents( if self.vae.config.force_upcast: image = image.float() self.vae.to(dtype=torch.float32) - + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -576,7 +524,7 @@ def prepare_image_latents( image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) return image_latents - + def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype ): @@ -748,7 +696,7 @@ def __call__( "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. 
Raise error if not correct - self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + # self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") @@ -891,7 +839,14 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -966,4 +921,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From 92a71bc0260e70d7fb2d41139331867f6761ecde Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:49:54 +1000 Subject: [PATCH 10/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index d6d688089d5f..eb066716fa45 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -19,6 +19,7 @@ from ...utils import ( is_accelerate_available, is_accelerate_version, + deprecate, logging, randn_tensor, replace_example_docstring, From 8028af8a41e9223fd39c39ea4e2bc411057a5d75 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:54:36 +1000 Subject: [PATCH 11/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index eb066716fa45..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -17,9 +17,9 @@ ) from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, - deprecate, logging, randn_tensor, replace_example_docstring, From e5d3ec45c77fda162ca19133e1d7b453b1b810e2 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:18:34 +1000 Subject: [PATCH 12/67] Support instruction pix2pix sdxl --- 
...ne_stable_diffusion_xl_instruct_pix2pix.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7622c2e00d8b..60571072b18b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -4,31 +4,25 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import (CLIPTextModel, CLIPTextModelWithProjection, + CLIPTokenizer) from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import (FromSingleFileMixin, LoraLoaderMixin, + TextualInversionLoaderMixin) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) +from ...models.attention_processor import (AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor) from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import (deprecate, is_accelerate_available, + is_accelerate_version, logging, randn_tensor, + replace_example_docstring) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From f6df1e84408b859e060ef26f9cf60a836af38e0d Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:20:26 +1000 Subject: [PATCH 13/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 60571072b18b..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -4,25 +4,31 @@ import numpy as np import PIL.Image import torch -from transformers import (CLIPTextModel, CLIPTextModelWithProjection, - CLIPTokenizer) +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import (FromSingleFileMixin, LoraLoaderMixin, - TextualInversionLoaderMixin) +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import (AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor) +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) from ...schedulers import 
KarrasDiffusionSchedulers -from ...utils import (deprecate, is_accelerate_available, - is_accelerate_version, logging, randn_tensor, - replace_example_docstring) +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker + logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From b6eb449186f2239aa9191ce1394718ac14de4723 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:21:37 +1000 Subject: [PATCH 14/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7622c2e00d8b..38bcc9f2a6b4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,6 +1,5 @@ -import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union - +import inspect import numpy as np import PIL.Image import torch @@ -28,7 +27,6 @@ from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From bc4377e68bf6f46ac3d9b9b738a4f9a820c1c0c2 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:28:33 +1000 Subject: [PATCH 15/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 38bcc9f2a6b4..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,5 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + import numpy as np import PIL.Image import torch @@ -27,6 +28,7 @@ from . 
import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker + logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From 94e9e1041007c0c547830edde564af1f1a3f0590 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Fri, 14 Jul 2023 00:17:05 +1000 Subject: [PATCH 16/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 1051 +++++++++++++++++ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 963 +++++++++++++++ 2 files changed, 2014 insertions(+) create mode 100644 examples/instruct_pix2pix/train_instruct_pix2pix_xl.py create mode 100644 src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py new file mode 100644 index 000000000000..7e708fe8e104 --- /dev/null +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -0,0 +1,1051 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to fine-tune Stable Diffusion for InstructPix2Pix.""" + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import PIL +import requests +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"), +} +WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--original_image_column", + type=str, + default="input_image", + help="The column of the dataset containing the original image on which edits where made.", + ) + parser.add_argument( + "--edited_image_column", + type=str, + default="edited_image", + help="The column of the dataset containing the edited image.", + ) + parser.add_argument( + "--edit_prompt_column", + type=str, + default="edit_prompt", + help="The column of the dataset containing the edit instruction.", + ) + parser.add_argument( + "--val_image_url", + type=str, + default=None, + help="URL to the original image that you would like to edit (used during inference for debugging purposes).", + ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="instruct-pix2pix-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=256, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--conditioning_dropout_prob", + type=float, + default=None, + help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. 
Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def convert_to_np(image, resolution): + if isinstance(image, str): + image = PIL.Image.open(image) + image = image.convert("RGB").resize((resolution, resolution)) + return np.array(image).transpose(2, 0, 1) + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. 
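+    # Note: SDXL ships two text encoders. Both are loaded below (CLIPTextModel and
+    # CLIPTextModelWithProjection); only the first tokenizer is instantiated at this point.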
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+    )
+
+    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+    # from the pre-trained checkpoints. The extra channels added to the first layer are
+    # initialized to zero.
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    # in_channels = 8
+    in_channels = 4
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
+
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available.
Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. 
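+    # Column names are resolved in this order: the CLI flag if provided, otherwise the
+    # DATASET_NAME_MAPPING entry for the dataset, otherwise the dataset's positional columns.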
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.original_image_column is None: + original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + original_image_column = args.original_image_column + if original_image_column not in column_names: + raise ValueError( + f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edit_prompt_column is None: + edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + edit_prompt_column = args.edit_prompt_column + if edit_prompt_column not in column_names: + raise ValueError( + f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edited_image_column is None: + edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] + else: + edited_image_column = args.edited_image_column + if edited_image_column not in column_names: + raise ValueError( + f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(captions): + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + ) + + def preprocess_images(examples): + original_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[original_image_column]] + ) + edited_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[edited_image_column]] + ) + # We need to ensure that the original and the edited images undergo the same + # augmentation transforms. + images = np.concatenate([original_images, edited_images]) + images = torch.tensor(images) + images = 2 * (images / 255) - 1 + return train_transforms(images) + + def preprocess_train(examples): + # Preprocess images. + preprocessed_images = preprocess_images(examples) + # Since the original and edited images were concatenated before + # applying the transformations, we need to separate them and reshape + # them accordingly. + original_images, edited_images = preprocessed_images.chunk(2) + original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) + + # Collate the preprocessed images into the `examples`. + examples["original_pixel_values"] = original_images + examples["edited_pixel_values"] = edited_images + + # Preprocess the captions. 
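+        # Edit prompts are tokenized with the first tokenizer only, padded/truncated to its
+        # maximum sequence length.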
+ captions = list(examples[edit_prompt_column]) + examples["input_ids"] = tokenize_captions(captions) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples]) + original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() + edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) + edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return { + "original_pixel_values": original_pixel_values, + "edited_pixel_values": edited_pixel_values, + "input_ids": input_ids, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + text_encoders = [text_encoder, text_encoder_2] + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("instruct-pix2pix", config=vars(args)) + + # Train! 
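+    # The effective batch size below is the per-device batch size multiplied by the number of
+    # processes and the gradient accumulation steps.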
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # We want to learn the denoising process w.r.t the edited images which + # are conditioned on the original image (which was edited) and the edit instruction. + # So, first, convert images to latent space. + latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning. 
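+                # Note: only the second (projection) text encoder is iterated over here
+                # (`text_encoders[1:]`); its penultimate hidden states become the cross-attention
+                # conditioning and its pooled output is kept for the SDXL added embeddings.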
+ prompt_embeds_list = [] + for text_encoder in text_encoders[1:]: + prompt_embeds = text_encoder(batch["input_ids"], output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + encoder_hidden_states = torch.concat(prompt_embeds_list, dim=-1) + + # Get the additional image embedding for conditioning. + # Instead of getting a diagonal Gaussian here, we simply take the mode. + original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode() + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = original_image_embeds.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + original_image_embeds = image_mask * original_image_embeds + + # Concatenate the `original_image_embeds` with the `noisy_latents`. 
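+                # The result has 8 latent channels: 4 from the noisy edited-image latents and
+                # 4 from the original-image conditioning latents.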
+                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                ### Begin SDXL
+                bs_embed = pooled_prompt_embeds.shape[0]
+                pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
+                    bs_embed * 1, -1
+                )
+                add_text_embeds = pooled_prompt_embeds
+
+                crops_coords_top_left = (0, 0)
+                target_size = (args.resolution, args.resolution)
+                original_size = original_image_embeds.shape[-2:]
+                add_time_ids = list(original_size + crops_coords_top_left + target_size)
+                add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+                add_time_ids = torch.tensor([add_time_ids], dtype=encoder_hidden_states.dtype)
+                add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=encoder_hidden_states.dtype)
+                add_time_ids = add_time_ids.to(encoder_hidden_states.device).repeat(args.train_batch_size, 1)
+                ### End SDXL
+
+                # Predict the noise residual and compute loss
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                model_pred = unet(
+                    concatenated_noisy_latents[:, :4, :, :],
+                    timesteps,
+                    encoder_hidden_states,
+                    cross_attention_kwargs=None,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                # Gather the losses across all processes for logging (if we use distributed training).
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet.parameters())
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+
accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if ( + (args.val_image_url is not None) + and (args.validation_prompt is not None) + and (epoch % args.validation_epochs == 0) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + # The models need unwrapping because for compatibility in distributed training mode. + pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + original_image = download_image(args.val_image_url) + edited_images = [] + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"validation": wandb_table}) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + del pipeline + torch.cuda.empty_cache() + + # Create the pipeline using the trained modules and save it. 
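+    # The trained (optionally EMA-averaged) UNet is unwrapped and packaged, together with the
+    # frozen text encoder and VAE, into the SDXL InstructPix2Pix pipeline before saving and,
+    # if requested, pushing to the Hub.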
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), + unet=unet, + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + if args.validation_prompt is not None: + edited_images = [] + pipeline = pipeline.to(accelerator.device) + with torch.autocast(str(accelerator.device).replace(":0", "")): + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"test": wandb_table}) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py new file mode 100644 index 000000000000..7bfcea0860c0 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -0,0 +1,963 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . 
import StableDiffusionXLPipelineOutput +from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
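+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder used by SDXL in addition to `text_encoder`.
+        tokenizer_2 (`CLIPTokenizer`):
+            Second tokenizer, paired with `text_encoder_2`.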
+ """ + _optional_components = ["tokenizer", "text_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.watermark = StableDiffusionXLWatermarker() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder_2, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.text_encoder is not None: + cpu_offload(self.text_encoder, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are 
only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the inital image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16 of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + aesthetic_score (`float`, *optional*, defaults to 6.0): + TODO + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + TDOO + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. + if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. 
But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. + if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) From ff0b5cda80e84964d10c9c38d3e7c2e6789cae13 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Fri, 14 Jul 2023 11:06:26 +1000 Subject: [PATCH 17/67] Support instruction pix2pix sdxl --- ...line_stable_diffusion_xl_instruct_pix2pix.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7bfcea0860c0..33556d060fa2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -855,6 +855,10 @@ def __call__( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + prompt_embeds = 
prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -874,14 +878,7 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - noise_pred = self.unet( - scaled_latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -960,4 +957,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From ffb851f0048d07ad71fd108e96f51e505e8fb444 Mon Sep 17 00:00:00 2001 From: Thomas Chambon <36728882+tchambon@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:49:41 +0200 Subject: [PATCH 18/67] [Community] Implementation of the IADB community pipeline (#3996) * community pipeline: implementation of iadb * iadb.py: reformat using black * iadb.py: linting update --- examples/community/README.md | 60 ++++++++++++++ examples/community/iadb.py | 149 +++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 examples/community/iadb.py diff --git a/examples/community/README.md b/examples/community/README.md index 17cd34a5182d..6967d273e449 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -38,6 +38,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | | CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) | | TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | +| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. 
```py @@ -1707,3 +1708,62 @@ output = pipeline( ``` ![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png) ![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png) + + +### IADB pipeline + +This pipeline is the implementation of the [α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper. +It is a simple and minimalist diffusion model. + +The following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model. + +```python + +pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline='iadb') + +pipeline_iadb = pipeline_iadb.to('cuda') + +output = pipeline_iadb(batch_size=4,num_inference_steps=128) +for i in range(len(output[0])): + plt.imshow(output[0][i]) + plt.show() + +``` + +Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it): + +```python + +def sample_iadb(model, x0, nb_step): + x_alpha = x0 + for t in range(nb_step): + alpha = (t/nb_step) + alpha_next =((t+1)/nb_step) + + d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))['sample'] + x_alpha = x_alpha + (alpha_next-alpha)*d + + return x_alpha + +``` + +The training loop is also straightforward: + +```python + +# Training loop +while True: + x0 = sample_noise() + x1 = sample_dataset() + + alpha = torch.rand(batch_size) + + # Blend + x_alpha = (1-alpha) * x0 + alpha * x1 + + # Loss + loss = torch.sum((D(x_alpha, alpha)- (x1-x0))**2) + optimizer.zero_grad() + loss.backward() + optimizer.step() +``` diff --git a/examples/community/iadb.py b/examples/community/iadb.py new file mode 100644 index 000000000000..1f421ee0ea4c --- /dev/null +++ b/examples/community/iadb.py @@ -0,0 +1,149 @@ +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import ConfigMixin +from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +class IADBScheduler(SchedulerMixin, ConfigMixin): + """ + IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist. + + For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html + """ + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x_alpha: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + Predict the sample at the previous timestep by reversing the ODE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. It is the direction from x0 to x1. + timestep (`float`): current timestep in the diffusion chain. 
+ x_alpha (`torch.FloatTensor`): x_alpha sample for the current timestep + + Returns: + `torch.FloatTensor`: the sample at the previous timestep + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + alpha = timestep / self.num_inference_steps + alpha_next = (timestep + 1) / self.num_inference_steps + + d = model_output + + x_alpha = x_alpha + (alpha_next - alpha) * d + + return x_alpha + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + alpha: torch.FloatTensor, + ) -> torch.FloatTensor: + return original_samples * alpha + noise * (1 - alpha) + + def __len__(self): + return self.config.num_train_timesteps + + +class IADBPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + x_alpha = image.clone() + for t in self.progress_bar(range(num_inference_steps)): + alpha = t / num_inference_steps + + # 1. predict noise model_output + model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample + + # 2. step + x_alpha = self.scheduler.step(model_output, t, x_alpha) + + image = (x_alpha * 0.5 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) From 926fc6d657142437aeebc95e838e3dd2bcbfaf0b Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 13 Jul 2023 10:26:51 -1000 Subject: [PATCH 19/67] add kandinsky to readme table (#4081) Co-authored-by: yiyixuxu --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 6c66be9f2463..7d9811362597 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,11 @@ just hang out ☕. if DeepFloyd/IF-I-XL-v1.0 + + Text-to-Image + Kandinsky + kandinsky-community/kandinsky-2-2-decoder + Text-guided Image-to-Image Controlnet From 7f0a22b1b5b877bb4544b31fdf5ce216d32d4b21 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 23:16:43 +0200 Subject: [PATCH 20/67] [From Single File] Force accelerate to be installed (#4078) force accelerate to be installed --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index a9094cf12f79..599f7826b7d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1172,6 +1172,11 @@ def download_from_original_stable_diffusion_ckpt( StableUnCLIPPipeline, ) + if not is_accelerate_available(): + raise ImportError( + "To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`." + ) + if pipeline_class is None: pipeline_class = StableDiffusionPipeline From cc8507dbbda1a759bd385b56c495a896c4009744 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sat, 15 Jul 2023 22:37:05 +1000 Subject: [PATCH 21/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 2 -- ...line_stable_diffusion_xl_instruct_pix2pix.py | 17 +++++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 7e708fe8e104..c3285074b0bb 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -448,7 +448,6 @@ def main(): unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) - import pdb; pdb.set_trace() # InstructPix2Pix uses an additional image for conditioning. To accommodate that, # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. 
This UNet is @@ -878,7 +877,6 @@ def collate_fn(examples): # Predict the noise residual and compute loss added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - import pdb; pdb.set_trace() model_pred = unet(concatenated_noisy_latents[:, :4, :, :], timesteps, encoder_hidden_states, cross_attention_kwargs=None, added_cond_kwargs=added_cond_kwargs, return_dict=False).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 33556d060fa2..b11d7cf27ba9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -528,7 +528,8 @@ def prepare_image_latents( raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - + + tmp_img = torch.load('/home/users/u5689359/gitRepo_mill/Lycium/tmp_image.pt').to('cuda') image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -536,6 +537,11 @@ def prepare_image_latents( if image.shape[1] == 4: image_latents = image else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -859,11 +865,12 @@ def __call__( add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) + prompt_embeds = prompt_embeds.to(device).to(torch.float32) + add_text_embeds = add_text_embeds.to(device).to(torch.float32) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 11. 
Denoising loop + self.unet = self.unet.to(torch.float32) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -878,6 +885,8 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: @@ -957,4 +966,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From d0790d84946565bb3fa4b8a3108561850747f372 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 15:17:36 +1000 Subject: [PATCH 22/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 83 +++++++++++++------ 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index c3285074b0bb..3ae320692884 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -52,6 +52,8 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline +from PIL import Image + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.18.0.dev0") @@ -435,10 +437,13 @@ def main(): # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained( + tokenizer_1 = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( + tokenizer_2 = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision + ) + text_encoder_1 = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( @@ -455,8 +460,7 @@ def main(): # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") - # in_channels = 8 - in_channels = 4 + in_channels = 8 out_channels = unet.conv_in.out_channels unet.register_to_config(in_channels=in_channels) @@ -470,7 +474,8 @@ def main(): # Freeze vae and text_encoder vae.requires_grad_(False) - text_encoder.requires_grad_(False) + text_encoder_1.requires_grad_(False) + text_encoder_2.requires_grad_(False) # Create EMA for the unet. if args.use_ema: @@ -614,9 +619,9 @@ def load_model_hook(models, input_dir): # Preprocessing the datasets. 
# We need to tokenize input captions and transform the images. - def tokenize_captions(captions): - inputs = tokenizer( - captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + def tokenize_captions(captions, a_tokenizer): + inputs = a_tokenizer( + captions, max_length=a_tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ) return inputs.input_ids @@ -658,7 +663,8 @@ def preprocess_train(examples): # Preprocess the captions. captions = list(examples[edit_prompt_column]) - examples["input_ids"] = tokenize_captions(captions) + examples["input_ids"] = tokenize_captions(captions, tokenizer_1) + examples["input_ids_2"] = tokenize_captions(captions, tokenizer_2) return examples with accelerator.main_process_first(): @@ -673,10 +679,12 @@ def collate_fn(examples): edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = torch.stack([example["input_ids"] for example in examples]) + input_ids_2 = torch.stack([example["input_ids_2"] for example in examples]) return { "original_pixel_values": original_pixel_values, "edited_pixel_values": edited_pixel_values, "input_ids": input_ids, + "input_ids_2": input_ids_2, } # DataLoaders creation: @@ -719,9 +727,9 @@ def collate_fn(examples): weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu and cast to weight_dtype - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) - text_encoders = [text_encoder, text_encoder_2] + text_encoders = [text_encoder_1, text_encoder_2] vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -792,6 +800,7 @@ def collate_fn(examples): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. # So, first, convert images to latent space. + # tmp_pixel_value = torch.load('xl_image.pt').to('cuda') latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor @@ -806,21 +815,24 @@ def collate_fn(examples): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning. 
+ ### Begin encoder prompt prompt_embeds_list = [] - for text_encoder in text_encoders[1:]: - prompt_embeds = text_encoder(batch["input_ids"], output_hidden_states=True) + for input_ids, text_encoder in zip((batch["input_ids"], batch["input_ids_2"]), text_encoders): + prompt_embeds = text_encoder(input_ids, output_hidden_states=True) + # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 1, 1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) - + encoder_hidden_states = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed, -1) + ### End encoder prompt # Get the additional image embedding for conditioning. # Instead of getting a diagonal Gaussian here, we simply take the mode. @@ -834,7 +846,17 @@ def collate_fn(examples): prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) # Final text conditioning. - null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] + ### Begin: Get null conditioning + null_conditioning_list = [] + for a_tokenizer, a_text_encoder in zip((tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2)): + null_conditioning_list.append( + a_text_encoder( + tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), + output_hidden_states=True + ).hidden_states[-2] + ) + ### End: Get null conditioning + null_conditioning = torch.concat(null_conditioning_list, dim=-1) encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. @@ -859,14 +881,10 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") ### Begin SDXL - bs_embed = pooled_prompt_embeds.shape[0] - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( - bs_embed * 1, -1 - ) add_text_embeds = pooled_prompt_embeds crops_coords_top_left = (0, 0) - target_size = (args.resolution, args.resolution) + target_size = (512, 512) original_size = original_image_embeds.shape[-2:] add_time_ids = list(original_size + crops_coords_top_left + target_size) add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -876,8 +894,12 @@ def collate_fn(examples): ### End SDXL # Predict the noise residual and compute loss + # tmp_prompt_embeds = torch.load('xl_prompt_embeds.pt').to('cuda') + # tmp_concatenated_noisy_latents = torch.load('xl_latent_model_input.pt').to('cuda') added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - model_pred = unet(concatenated_noisy_latents[:, :4, :, :], timesteps, encoder_hidden_states, cross_attention_kwargs=None, added_cond_kwargs=added_cond_kwargs, return_dict=False).sample + + model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + # model_pred = unet(concatenated_noisy_latents, timesteps, tmp_prompt_embeds, added_cond_kwargs=added_cond_kwargs).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
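The hunk above wires SDXL's extra conditioning into the training step: the pooled output of the second text encoder is passed as `text_embeds`, and the six micro-conditioning values (original size, crop coordinates, target size) are flattened into `time_ids`. A rough sketch of what ends up in `added_cond_kwargs` follows; the batch size, the 512x512 sizes, and the 1280-dimensional pooled embedding are illustrative assumptions, not values taken from this diff.

```python
import torch

# Illustrative values only: a batch of 2 samples trained at 512x512.
bsz = 2
original_size = (512, 512)
crops_coords_top_left = (0, 0)
target_size = (512, 512)

# The six micro-conditioning integers are flattened into one time-ids vector per sample.
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.float32)
add_time_ids = add_time_ids.repeat(bsz, 1)  # shape (2, 6)

# SDXL also consumes the pooled projection of the second text encoder (1280-dim here by assumption).
pooled_prompt_embeds = torch.randn(bsz, 1280)

added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
# unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
```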
@@ -952,7 +974,10 @@ def collate_fn(examples): pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), vae=accelerator.unwrap_model(vae), revision=args.revision, torch_dtype=weight_dtype, @@ -961,7 +986,8 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - original_image = download_image(args.val_image_url) + # original_image = download_image(args.val_image_url) + original_image = Image.open(args.val_image_url).convert("RGB") edited_images = [] with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" @@ -1002,7 +1028,10 @@ def collate_fn(examples): pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), vae=accelerator.unwrap_model(vae), unet=unet, revision=args.revision, @@ -1046,4 +1075,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() + main() \ No newline at end of file From bfe12e95194af4dd34645a645b6618c9810722d0 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 17:14:11 +1000 Subject: [PATCH 23/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 3ae320692884..572a4dad2ca3 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -144,7 +144,7 @@ def parse_args(): parser.add_argument( "--validation_epochs", type=int, - default=1, + default=100, help=( "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." @@ -952,6 +952,61 @@ def collate_fn(examples): logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) + ### BEGIN: Perform validation every `validation_epochs` steps + if global_step % args.validation_epochs == 0 or global_step == 1: + if ( + (args.val_image_url is not None) + and (args.validation_prompt is not None) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + + # create pipeline + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + # The models need unwrapping because for compatibility in distributed training mode. 
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder_1), + text_encoder_2=accelerator.unwrap_model(text_encoder_2), + tokenizer=accelerator.unwrap_model(tokenizer_1), + tokenizer_2=accelerator.unwrap_model(tokenizer_2), + vae=accelerator.unwrap_model(vae), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + # Save validation images + val_save_dir = os.path.join(args.output_dir, "validation_images") + if not os.path.exists(val_save_dir): + os.makedirs(val_save_dir) + + # original_image = download_image(args.val_image_url) + original_image = Image.open(args.val_image_url).convert("RGB") + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): + for val_img_idx in range(args.num_validation_images): + a_val_img = pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png")) + ### END: Perform validation every `validation_epochs` steps + if global_step >= args.max_train_steps: break @@ -986,6 +1041,11 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference + # Save validation images + val_save_dir = os.path.join(args.output_dir, "validation_images") + if not os.path.exists(val_save_dir): + os.makedirs(val_save_dir) + # original_image = download_image(args.val_image_url) original_image = Image.open(args.val_image_url).convert("RGB") edited_images = [] From 84f5bc98e8e9aac12f1db80a9eaf34c9cfd7920b Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 19:37:12 +1000 Subject: [PATCH 24/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 50 +++++++++++++++---- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 44 ++++++++-------- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 572a4dad2ca3..5c889824dac2 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -41,7 +41,8 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel @@ -66,6 +67,26 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers 
import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") parser.add_argument( @@ -436,22 +457,31 @@ def main(): ).repo_id # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer_1 = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + tokenizer_1 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + text_encoder_cls_1 = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision + text_encoder_cls_2 = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) - text_encoder_1 = CLIPTextModel.from_pretrained( + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_1 = text_encoder_cls_1.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + text_encoder_2 = text_encoder_cls_2.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # InstructPix2Pix uses an additional image for conditioning. To accommodate that, @@ -459,7 +489,7 @@ def main(): # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. 
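For reference, the zero initialization described in the comment above is usually done by swapping `conv_in` for a wider convolution and copying the pretrained weights into the first four input channels, along the lines of the original InstructPix2Pix training script. A minimal sketch, assuming `unet` is an already loaded `UNet2DConditionModel`:

```python
import torch
import torch.nn as nn

# Sketch: widen the UNet's first conv from 4 to 8 input channels so it can take the
# concatenated [noisy latents, image latents]; the new weight slice starts at zero so the
# pretrained behaviour is preserved at the start of fine-tuning.
in_channels = 8
out_channels = unet.conv_in.out_channels  # `unet` assumed to be a loaded UNet2DConditionModel
unet.register_to_config(in_channels=in_channels)

with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)  # keep pretrained latent-channel weights
    unet.conv_in = new_conv_in
```

Starting from zeros means the widened UNet initially ignores the image-conditioning channels and reproduces the pretrained model's behaviour.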
- logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + logger.info("Initializing the XL InstructPix2Pix UNet from the pretrained UNet.") in_channels = 8 out_channels = unet.conv_in.out_channels unet.register_to_config(in_channels=in_channels) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b11d7cf27ba9..aef1d22021f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -529,7 +529,6 @@ def prepare_image_latents( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - tmp_img = torch.load('/home/users/u5689359/gitRepo_mill/Lycium/tmp_image.pt').to('cuda') image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -794,22 +793,26 @@ def __call__( ) # 4. Preprocess image - image = self.image_processor.preprocess(image) + image = self.image_processor.preprocess(image).to(device) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - do_classifier_free_guidance, - generator, - ) + # image_latents = self.prepare_image_latents( + # image, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # do_classifier_free_guidance, + # generator, + # ) + image_latents = self.vae.encode(image).latent_dist.mode() + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -861,9 +864,9 @@ def __call__( add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) - add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) - prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) prompt_embeds = prompt_embeds.to(device).to(torch.float32) add_text_embeds = add_text_embeds.to(device).to(torch.float32) @@ -877,7 +880,8 @@ def __call__( # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
- latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -885,7 +889,7 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] @@ -900,12 +904,8 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_image) - + image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf From f827c0a134b6a48e839f23a7604b9aa8eb617401 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:24:10 +1000 Subject: [PATCH 25/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 871 ++++++++++++------ 1 file changed, 578 insertions(+), 293 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index aef1d22021f1..2d3ad8f8eb30 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -457,167 +457,154 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) +# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) +import inspect +import warnings +from typing import Callable, List, Optional, Union - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - return timesteps, num_inference_steps - t_start +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - batch_size = batch_size * num_images_per_prompt +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. - if image.shape[1] == 4: - image_latents = image - else: - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - image_latents = torch.cat([image_latents], dim=0) + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] - return image_latents - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + super().__init__() - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." 
+ if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: + + if safety_checker is not None and feature_extractor is None: raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -639,19 +626,10 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, ): r""" Function invoked when calling the pipeline for generation. @@ -660,9 +638,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): - The image(s) to modify with the pipeline. - num_inference_steps (`int`, *optional*, defaults to 50): + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also + accpet image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -670,7 +649,7 @@ def __call__( `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + usually at the expense of lower image quality. This pipeline requires a value of at least `1`. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to @@ -678,14 +657,14 @@ def __call__( image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): @@ -699,18 +678,11 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -718,42 +690,50 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - guidance_rescale (`float`, *optional*, defaults to 0.7): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - TODO - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - TODO - target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - TODO - aesthetic_score (`float`, *optional*, defaults to 6.0): - TODO - negative_aesthetic_score (`float`, *optional*, defaults to 2.5): - TDOO Examples: + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "make the mountains snowy" + >>> image = pipe(prompt=prompt, image=image).images[0] + ``` + Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. """ - # 1. Check inputs. Raise error if not correct + # 0. Check inputs self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") - # 2. Define call parameters + # 1. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -762,7 +742,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -770,16 +749,8 @@ def __call__( # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, @@ -787,32 +758,25 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, ) - # 4. Preprocess image - image = self.image_processor.preprocess(image).to(device) + # 3. Preprocess image + image = self.image_processor.preprocess(image) - # 5. Prepare timesteps + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 6. Prepare Image latents - # image_latents = self.prepare_image_latents( - # image, - # batch_size, - # num_images_per_prompt, - # prompt_embeds.dtype, - # device, - # do_classifier_free_guidance, - # generator, - # ) - image_latents = self.vae.encode(image).latent_dist.mode() - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -831,67 +795,37 @@ def __call__( latents, ) - # 8. Check that shapes of latents and image match the UNet channels + # 7. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents + num_channels_image}. Please verify the config of" + f" = {num_channels_latents+num_channels_image}. 
Please verify the config of" " `pipeline.unet` or your `image` input." ) - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 10. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids, add_neg_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - dtype=prompt_embeds.dtype, - ) - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) - - # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) - # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) - # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) - - prompt_embeds = prompt_embeds.to(device).to(torch.float32) - add_text_embeds = add_text_embeds.to(device).to(torch.float32) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 11. Denoising loop - self.unet = self.unet.to(torch.float32) + # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
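Before the restored three-way expansion below, it helps to spell out the row ordering it relies on; a short sketch, with names taken from the surrounding code:

```py
# Row ordering assumed by the three-way InstructPix2Pix guidance below:
#   latent_model_input -> [latents, latents,  latents ]
#   image_latents      -> [image,   image,    zeros   ]   # built in prepare_image_latents
#   prompt_embeds      -> [text,    negative, negative]   # built in _encode_prompt
# so the three chunks of the UNet output are (text+image, image-only, unconditional):
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
```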
- # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} - # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] + noise_pred = self.unet( + scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -904,12 +838,12 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -929,41 +863,392 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - return StableDiffusionXLPipelineOutput(images=image) + has_nsfw_concept = None - image = self.watermark.apply_watermark(image) - image = self.image_processor.postprocess(image, output_type=output_type) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = 
self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image,) + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get 
unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept - return StableDiffusionXLPipelineOutput(images=image) + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents From 4a76d36de014137525dcbe30b55042c9f426d8de Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:28:32 +1000 Subject: [PATCH 26/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 871 ++++++------------ 1 file changed, 293 insertions(+), 578 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 2d3ad8f8eb30..e7a36349cbeb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -356,7 +356,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -457,154 +457,167 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def check_inputs( -# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) -import inspect -import warnings -from typing import Callable, List, Optional, Union + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") -import numpy as np -import PIL -import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, -) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + return timesteps, num_inference_steps - t_start -def preprocess(image): - warnings.warn( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead", - FutureWarning, - ) - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): - r""" - Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." + ) - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) - In addition the pipeline inherits the following loading methods: - - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) - as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + batch_size = batch_size * num_images_per_prompt - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." ) - - if safety_checker is not None and feature_extractor is None: + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
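A small editorial illustration (hypothetical numbers, not from the diff) of the batching rule enforced in `prepare_image_latents`: image latents may only be duplicated when the effective prompt count is a whole multiple of the number of input images.

```py
# Hypothetical numbers illustrating the duplication rule.
batch_size = 4          # prompts * num_images_per_prompt
num_image_latents = 2   # encoded input images

if batch_size > num_image_latents and batch_size % num_image_latents == 0:
    additional_image_per_prompt = batch_size // num_image_latents  # each latent repeated 2x
elif batch_size > num_image_latents:
    raise ValueError("cannot broadcast images to prompts")  # e.g. 3 images for 4 prompts
```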
) + else: + image_latents = torch.cat([image_latents], dim=0) - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
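For readers unfamiliar with the dimension check in `_get_add_time_ids`, here is a rough worked example. The specific numbers below are the usual published SDXL values (assumed for illustration, not taken from this patch):

```py
# Assumed SDXL dims: each time-id scalar is embedded into `addition_time_embed_dim`
# features and concatenated with the pooled text embedding.
addition_time_embed_dim = 256   # unet.config.addition_time_embed_dim
projection_dim = 1280           # text_encoder_2.config.projection_dim

# Base-style conditioning: original_size (2) + crop coords (2) + target_size (2) = 6 ids
passed_base = addition_time_embed_dim * 6 + projection_dim     # 2816
# Refiner-style conditioning: original_size (2) + crop coords (2) + aesthetic score (1) = 5 ids
passed_refiner = addition_time_embed_dim * 5 + projection_dim  # 2560

# The pipeline compares the passed value against unet.add_embedding.linear_1.in_features;
# a mismatch of exactly one time-id width triggers the `requires_aesthetics_score` hints above.
print(passed_base, passed_refiner)
```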
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -626,10 +639,19 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, ): r""" Function invoked when calling the pipeline for generation. @@ -638,10 +660,9 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also - accpet image latents as `image`, if passing latents directly, it will not be encoded again. - num_inference_steps (`int`, *optional*, defaults to 100): + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -649,7 +670,7 @@ def __call__( `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. This pipeline requires a value of at least `1`. + usually at the expense of lower image quality. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to @@ -657,14 +678,14 @@ def __call__( image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
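As an editorial aside (not part of the diff), a hedged usage sketch of the new pipeline based only on the arguments visible in this `__call__` signature; the checkpoint path is a placeholder, not a real model id:

```py
# Hedged usage sketch; "path/to/sdxl-instruct-pix2pix" is a placeholder checkpoint.
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
    StableDiffusionXLInstructPix2PixPipeline,
)
from diffusers.utils import load_image

pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "path/to/sdxl-instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

image = load_image(
    "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
).resize((1024, 1024))

edited = pipe(
    prompt="make the mountains snowy",
    image=image,
    num_inference_steps=50,
    guidance_scale=7.5,        # text guidance weight
    image_guidance_scale=1.5,  # pulls the result toward the input image
).images[0]
```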
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): @@ -678,11 +699,18 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -690,50 +718,42 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + aesthetic_score (`float`, *optional*, defaults to 6.0): + TODO + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + TDOO Examples: - ```py - >>> import PIL - >>> import requests - >>> import torch - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionInstructPix2PixPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... 
return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - - >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" - - >>> image = download_image(img_url).resize((512, 512)) - - >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( - ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "make the mountains snowy" - >>> image = pipe(prompt=prompt, image=image).images[0] - ``` - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Check inputs + # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") - # 1. Define call parameters + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -742,6 +762,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -749,8 +770,16 @@ def __call__( # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - # 2. Encode input prompt - prompt_embeds = self._encode_prompt( + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -758,25 +787,32 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) - # 3. Preprocess image - image = self.image_processor.preprocess(image) + # 4. Preprocess image + image = self.image_processor.preprocess(image).to(device) - # 4. set timesteps + # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 5. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - do_classifier_free_guidance, - generator, - ) + # 6. 
Prepare Image latents + # image_latents = self.prepare_image_latents( + # image, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # do_classifier_free_guidance, + # generator, + # ) + image_latents = self.vae.encode(image).latent_dist.mode() + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor @@ -795,37 +831,67 @@ def __call__( latents, ) - # 7. Check that shapes of latents and image match the UNet channels + # 8. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 9. Denoising loop + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) + # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) + # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + + prompt_embeds = prompt_embeds.to(device).to(torch.float32) + add_text_embeds = add_text_embeds.to(device).to(torch.float32) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + self.unet = self.unet.to(torch.float32) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. 
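Editorial sketch (hypothetical shapes) of the channel-wise conditioning that the in-channels check above guards: the noisy latents and the encoded conditioning image are concatenated along the channel dimension before every UNet call.

```py
# Hypothetical shapes showing why the UNet must accept
# num_channels_latents + num_channels_image input channels.
import torch

latents = torch.randn(2, 4, 128, 128)        # noisy latents being denoised
image_latents = torch.randn(2, 4, 128, 128)  # VAE-encoded conditioning image

unet_input = torch.cat([latents, image_latents], dim=1)  # (2, 8, 128, 128)
assert unet_input.shape[1] == 4 + 4  # must equal unet.config.in_channels (typically 8 for InstructPix2Pix UNets)
```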
- latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - noise_pred = self.unet( - scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False - )[0] + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} + # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] + noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -838,12 +904,12 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_image) - + image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -863,392 +929,41 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - has_nsfw_concept = None + return StableDiffusionXLPipelineOutput(images=image) - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. 
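For context (editorial, not from this patch): `rescale_noise_cfg` referenced above implements the guidance-rescale trick from the cited paper. A sketch of what it computes, which may differ in small details from the library's exact code:

```py
# Sketch of guidance rescaling (eqs. 15-16 of https://arxiv.org/pdf/2305.08891.pdf).
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the std of the guided prediction to the text-conditioned prediction...
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # ...then blend with the original prediction by `guidance_rescale`.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```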
Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_ prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept + return (image,) - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. 
Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - image_latents = image - else: - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) - - return image_latents + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From f43c0d5a2621b38a062912210b3559096adf0f59 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:47:06 +1000 Subject: [PATCH 27/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 54 +++++++------ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 75 ++++--------------- 2 files changed, 45 insertions(+), 84 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 5c889824dac2..657ca530aa9b 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -41,6 +41,7 @@ from packaging import version from torchvision import transforms from tqdm.auto import tqdm + # from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import AutoTokenizer, PretrainedConfig @@ -51,7 +52,9 @@ from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import ( + StableDiffusionXLInstructPix2PixPipeline, +) from PIL import Image @@ -463,9 +466,7 @@ def main(): tokenizer_2 = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False ) - text_encoder_cls_1 = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) + text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) text_encoder_cls_2 = import_model_class_from_model_name_or_path( 
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) @@ -651,7 +652,11 @@ def load_model_hook(models, input_dir): # We need to tokenize input captions and transform the images. def tokenize_captions(captions, a_tokenizer): inputs = a_tokenizer( - captions, max_length=a_tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + captions, + max_length=a_tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", ) return inputs.input_ids @@ -714,7 +719,7 @@ def collate_fn(examples): "original_pixel_values": original_pixel_values, "edited_pixel_values": edited_pixel_values, "input_ids": input_ids, - "input_ids_2": input_ids_2, + "input_ids_2": input_ids_2, } # DataLoaders creation: @@ -878,11 +883,13 @@ def collate_fn(examples): # Final text conditioning. ### Begin: Get null conditioning null_conditioning_list = [] - for a_tokenizer, a_text_encoder in zip((tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2)): + for a_tokenizer, a_text_encoder in zip( + (tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2) + ): null_conditioning_list.append( a_text_encoder( - tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), - output_hidden_states=True + tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), + output_hidden_states=True, ).hidden_states[-2] ) ### End: Get null conditioning @@ -912,7 +919,7 @@ def collate_fn(examples): ### Begin SDXL add_text_embeds = pooled_prompt_embeds - + crops_coords_top_left = (0, 0) target_size = (512, 512) original_size = original_image_embeds.shape[-2:] @@ -922,13 +929,15 @@ def collate_fn(examples): add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=encoder_hidden_states.dtype) add_time_ids = add_time_ids.to(encoder_hidden_states.device).repeat(args.train_batch_size, 1) ### End SDXL - + # Predict the noise residual and compute loss # tmp_prompt_embeds = torch.load('xl_prompt_embeds.pt').to('cuda') # tmp_concatenated_noisy_latents = torch.load('xl_latent_model_input.pt').to('cuda') added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample + + model_pred = unet( + concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ).sample # model_pred = unet(concatenated_noisy_latents, timesteps, tmp_prompt_embeds, added_cond_kwargs=added_cond_kwargs).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -983,11 +992,8 @@ def collate_fn(examples): progress_bar.set_postfix(**logs) ### BEGIN: Perform validation every `validation_epochs` steps - if global_step % args.validation_epochs == 0 or global_step == 1: - if ( - (args.val_image_url is not None) - and (args.validation_prompt is not None) - ): + if global_step % args.validation_epochs == 0 or global_step == 1: + if (args.val_image_url is not None) and (args.validation_prompt is not None): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." @@ -998,7 +1004,7 @@ def collate_fn(examples): # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) - + # The models need unwrapping because for compatibility in distributed training mode. 
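Editorial sketch (not from the patch) of how the two SDXL text encoders are typically combined in this kind of training script; the dimensions are the usual SDXL ones and the tensors are random stand-ins for illustration:

```py
# Assumed SDXL encoder dims (768 for CLIP ViT-L, 1280 for OpenCLIP bigG); illustrative only.
import torch

batch = 2
prompt_embeds_1 = torch.randn(batch, 77, 768)    # text_encoder_1 hidden_states[-2]
prompt_embeds_2 = torch.randn(batch, 77, 1280)   # text_encoder_2 hidden_states[-2]
pooled_prompt_embeds = torch.randn(batch, 1280)  # text_encoder_2 pooled/projected output

# Penultimate hidden states are concatenated on the feature axis, mirroring
# `torch.concat(prompt_embeds_list, dim=-1)` in the pipeline's encode_prompt.
encoder_hidden_states = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)  # (2, 77, 2048)
add_text_embeds = pooled_prompt_embeds  # fed to the UNet via added_cond_kwargs["text_embeds"]
```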
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1015,9 +1021,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - # Save validation images + # Save validation images val_save_dir = os.path.join(args.output_dir, "validation_images") - if not os.path.exists(val_save_dir): + if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) # original_image = download_image(args.val_image_url) @@ -1071,9 +1077,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - # Save validation images + # Save validation images val_save_dir = os.path.join(args.output_dir, "validation_images") - if not os.path.exists(val_save_dir): + if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) # original_image = download_image(args.val_image_url) @@ -1165,4 +1171,4 @@ def collate_fn(examples): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index e7a36349cbeb..d6d688089d5f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,17 +1,3 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -356,7 +342,7 @@ def encode_prompt( # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - + prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) @@ -455,44 +441,6 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix - def check_inputs( - self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -519,7 +467,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - + # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def prepare_image_latents( self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None @@ -528,7 +476,7 @@ def prepare_image_latents( raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - + image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -540,7 +488,7 @@ def prepare_image_latents( if self.vae.config.force_upcast: image = image.float() self.vae.to(dtype=torch.float32) - + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -576,7 +524,7 @@ def prepare_image_latents( image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) return image_latents - + def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype ): @@ -748,7 +696,7 @@ def __call__( "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. 
Raise error if not correct - self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + # self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") @@ -891,7 +839,14 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False)[0] + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -966,4 +921,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From 83a44761a75acd621fa8fa86675564a719dcae8a Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:49:54 +1000 Subject: [PATCH 28/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index d6d688089d5f..eb066716fa45 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -19,6 +19,7 @@ from ...utils import ( is_accelerate_available, is_accelerate_version, + deprecate, logging, randn_tensor, replace_example_docstring, From 9aa1e83be4a47e4a7408885d3b8eb1faaabf1c49 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Sun, 16 Jul 2023 23:54:36 +1000 Subject: [PATCH 29/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index eb066716fa45..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -17,9 +17,9 @@ ) from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, - deprecate, logging, randn_tensor, replace_example_docstring, From c04e813a35e63d504e5333b77a51a98d1eb2e603 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:18:34 +1000 Subject: [PATCH 30/67] Support instruction pix2pix sdxl --- 
...ne_stable_diffusion_xl_instruct_pix2pix.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7622c2e00d8b..60571072b18b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -4,31 +4,25 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import (CLIPTextModel, CLIPTextModelWithProjection, + CLIPTokenizer) from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import (FromSingleFileMixin, LoraLoaderMixin, + TextualInversionLoaderMixin) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) +from ...models.attention_processor import (AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor) from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import (deprecate, is_accelerate_available, + is_accelerate_version, logging, randn_tensor, + replace_example_docstring) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From 30fddaf37be28f31798b8b7da59c28c1b9a89598 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:20:26 +1000 Subject: [PATCH 31/67] Support instruction pix2pix sdxl --- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 60571072b18b..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -4,25 +4,31 @@ import numpy as np import PIL.Image import torch -from transformers import (CLIPTextModel, CLIPTextModelWithProjection, - CLIPTokenizer) +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import (FromSingleFileMixin, LoraLoaderMixin, - TextualInversionLoaderMixin) +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention_processor import (AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor) +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) from ...schedulers import 
KarrasDiffusionSchedulers -from ...utils import (deprecate, is_accelerate_available, - is_accelerate_version, logging, randn_tensor, - replace_example_docstring) +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker + logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From 0b6fcd41289802b7ff47e0567e2fd8467aa822d0 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:21:37 +1000 Subject: [PATCH 32/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7622c2e00d8b..38bcc9f2a6b4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,6 +1,5 @@ -import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union - +import inspect import numpy as np import PIL.Image import torch @@ -28,7 +27,6 @@ from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From b806da3ef43597cc47ab7e6cc96d5058ccc1ac8e Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 00:28:33 +1000 Subject: [PATCH 33/67] Support instruction pix2pix sdxl --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 38bcc9f2a6b4..7622c2e00d8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -1,5 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + import numpy as np import PIL.Image import torch @@ -27,6 +28,7 @@ from . 
import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker + logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ From 9978715c58a609a7f0e2434d21cc92ce8a28484c Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 08:48:24 +1000 Subject: [PATCH 34/67] Support instruction pix2pix sdxl --- ...eline_stable_diffusion_xl_instruct_pix2pix.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 7622c2e00d8b..b53f2ccf0021 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -32,22 +32,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusionXLImg2ImgPipeline - >>> from diffusers.utils import load_image - - >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" - - >>> init_image = load_image(url).convert("RGB") - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt, image=init_image).images[0] - ``` """ From a4f5455e595cd24f281eba46191941a336e66e10 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 20:23:47 +1000 Subject: [PATCH 35/67] Support instruction pix2pix sdxl --- .../train_instruct_pix2pix_xl.py | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 657ca530aa9b..7584554cf1d3 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -1,20 +1,4 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
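With the copied img2img example stripped from `EXAMPLE_DOC_STRING` in patch 34 and no replacement added yet, a minimal usage sketch of the new pipeline may help orientation. The checkpoint identifier and image URL below are placeholders for illustration, not values taken from this patch series; the guidance values mirror the ones the training script later uses for its validation runs:

```python
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
    StableDiffusionXLInstructPix2PixPipeline,
)
from diffusers.utils import load_image

# Hypothetical fine-tuned checkpoint; substitute the model produced by train_instruct_pix2pix_xl.py.
pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "your-org/sdxl-instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

original_image = load_image("https://example.com/input.png").convert("RGB")
edited_image = pipe(
    "make the sky look like a sunset",
    image=original_image,
    num_inference_steps=20,
    guidance_scale=7,
    image_guidance_scale=1.5,
).images[0]
edited_image.save("edited.png")
```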
- -"""Script to fine-tune Stable Diffusion for InstructPix2Pix.""" +"""Script to train Stable Diffusion XL for InstructPix2Pix.""" import argparse import logging @@ -39,24 +23,20 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version +from PIL import Image from torchvision import transforms from tqdm.auto import tqdm - -# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import AutoTokenizer, PretrainedConfig import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available - from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import ( StableDiffusionXLInstructPix2PixPipeline, ) - -from PIL import Image +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. From 3d32d50ce314c17581d67b0d66994f2dd0701908 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 22:55:57 +1000 Subject: [PATCH 36/67] Clean up IP2P SDXL code --- .../train_instruct_pix2pix_xl.py | 123 ++++++------------ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 112 ++++++++++------ 2 files changed, 115 insertions(+), 120 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 7584554cf1d3..3a9ab48198da 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -5,7 +5,9 @@ import math import os import shutil +import warnings from pathlib import Path +from urllib.parse import urlparse import accelerate import datasets @@ -50,6 +52,14 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] +def is_url(string): + try: + result = urlparse(string) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): @@ -131,7 +141,7 @@ def parse_args(): help="The column of the dataset containing the edit instruction.", ) parser.add_argument( - "--val_image_url", + "--val_image_url_or_path", type=str, default=None, help="URL to the original image that you would like to edit (used during inference for debugging purposes).", @@ -146,11 +156,11 @@ def parse_args(): help="Number of images that should be generated during validation with `validation_prompt`.", ) parser.add_argument( - "--validation_epochs", + "--validation_steps", type=int, default=100, help=( - "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + "Run fine-tuning validation every X steps. The validation process consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`." 
), ) @@ -738,8 +748,11 @@ def collate_fn(examples): weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 + warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) + elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) # Move text_encode and vae to gpu and cast to weight_dtype text_encoder_1.to(accelerator.device, dtype=weight_dtype) @@ -815,7 +828,6 @@ def collate_fn(examples): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. # So, first, convert images to latent space. - # tmp_pixel_value = torch.load('xl_image.pt').to('cuda') latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor @@ -840,7 +852,7 @@ def collate_fn(examples): bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, 1, 1) + # prompt_embeds = prompt_embeds.repeat(1, 1, 1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -901,7 +913,7 @@ def collate_fn(examples): add_text_embeds = pooled_prompt_embeds crops_coords_top_left = (0, 0) - target_size = (512, 512) + target_size = (args.resolution, args.resolution) original_size = original_image_embeds.shape[-2:] add_time_ids = list(original_size + crops_coords_top_left + target_size) add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -911,14 +923,11 @@ def collate_fn(examples): ### End SDXL # Predict the noise residual and compute loss - # tmp_prompt_embeds = torch.load('xl_prompt_embeds.pt').to('cuda') - # tmp_concatenated_noisy_latents = torch.load('xl_latent_model_input.pt').to('cuda') added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} model_pred = unet( concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ).sample - # model_pred = unet(concatenated_noisy_latents, timesteps, tmp_prompt_embeds, added_cond_kwargs=added_cond_kwargs).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). @@ -972,8 +981,8 @@ def collate_fn(examples): progress_bar.set_postfix(**logs) ### BEGIN: Perform validation every `validation_epochs` steps - if global_step % args.validation_epochs == 0 or global_step == 1: - if (args.val_image_url is not None) and (args.validation_prompt is not None): + if global_step % args.validation_steps == 0 or global_step == 1: + if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
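The `concatenated_noisy_latents` tensor passed to the UNet above follows the InstructPix2Pix recipe of stacking the noisy edited-image latents with the clean original-image latents along the channel dimension. A shape-only sketch, assuming the usual 4-channel SDXL latents (the matching 8-channel `conv_in` expansion is part of the standard InstructPix2Pix setup and is not visible in this hunk):

```python
import torch

# Noisy latents of the edited image (after noise_scheduler.add_noise) and the
# clean latents of the original image (vae.encode(...).latent_dist.mode()).
noisy_latents = torch.randn(2, 4, 32, 32)
original_image_embeds = torch.randn(2, 4, 32, 32)

# Channel-wise conditioning: the UNet receives 4 + 4 = 8 input channels.
concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
assert concatenated_noisy_latents.shape == (2, 8, 32, 32)
```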
@@ -1006,11 +1015,14 @@ def collate_fn(examples): if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) - # original_image = download_image(args.val_image_url) - original_image = Image.open(args.val_image_url).convert("RGB") + if is_url(args.val_image_url_or_path): + original_image = download_image(args.val_image_url_or_path) + else: + original_image = Image.open(args.val_image_url_or_path).convert("RGB") with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" ): + edited_images = [] for val_img_idx in range(args.num_validation_images): a_val_img = pipeline( args.validation_prompt, @@ -1020,81 +1032,28 @@ def collate_fn(examples): guidance_scale=7, generator=generator, ).images[0] + edited_images.append(a_val_img) a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png")) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"validation": wandb_table}) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + del pipeline + torch.cuda.empty_cache() ### END: Perform validation every `validation_epochs` steps if global_step >= args.max_train_steps: break - if accelerator.is_main_process: - if ( - (args.val_image_url is not None) - and (args.validation_prompt is not None) - and (epoch % args.validation_epochs == 0) - ): - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # create pipeline - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - # The models need unwrapping because for compatibility in distributed training mode. 
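Both the new step-based validation block above and the epoch-based block being removed below wrap inference in the same EMA weight swap. A self-contained sketch of that pattern with diffusers' `EMAModel`, using a tiny stand-in module so it can actually run:

```python
import torch
from diffusers.training_utils import EMAModel

# Tiny stand-in for the UNet so the swap pattern is runnable end to end.
unet = torch.nn.Linear(4, 4)
ema_unet = EMAModel(unet.parameters())

# ... during training, ema_unet.step(unet.parameters()) keeps the average up to date ...

ema_unet.store(unet.parameters())    # stash the current training weights
ema_unet.copy_to(unet.parameters())  # load the EMA-averaged weights for inference
# run validation with the EMA weights here
ema_unet.restore(unet.parameters())  # put the training weights back afterwards
```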
- pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder_1), - text_encoder_2=accelerator.unwrap_model(text_encoder_2), - tokenizer=accelerator.unwrap_model(tokenizer_1), - tokenizer_2=accelerator.unwrap_model(tokenizer_2), - vae=accelerator.unwrap_model(vae), - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - # Save validation images - val_save_dir = os.path.join(args.output_dir, "validation_images") - if not os.path.exists(val_save_dir): - os.makedirs(val_save_dir) - - # original_image = download_image(args.val_image_url) - original_image = Image.open(args.val_image_url).convert("RGB") - edited_images = [] - with torch.autocast( - str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" - ): - for _ in range(args.num_validation_images): - edited_images.append( - pipeline( - args.validation_prompt, - image=original_image, - num_inference_steps=20, - image_guidance_scale=1.5, - guidance_scale=7, - generator=generator, - ).images[0] - ) - - for tracker in accelerator.trackers: - if tracker.name == "wandb": - wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) - for edited_image in edited_images: - wandb_table.add_data( - wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt - ) - tracker.log({"validation": wandb_table}) - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) - - del pipeline - torch.cuda.empty_cache() - # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b53f2ccf0021..002db6b4f8ab 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -22,7 +22,6 @@ is_accelerate_version, logging, randn_tensor, - replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput @@ -31,9 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -EXAMPLE_DOC_STRING = """ -""" - def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -51,7 +47,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
@@ -59,7 +55,6 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] @@ -68,18 +63,26 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of + Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ - _optional_components = ["tokenizer", "text_encoder"] def __init__( self, @@ -221,7 +224,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt, @@ -435,6 +437,43 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start + def check_inputs( + self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -453,7 +492,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - # Copy from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix def prepare_image_latents( self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None ): @@ -510,6 +548,7 @@ def prepare_image_latents( return image_latents + # Copied from diffusers.src.diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype ): @@ -550,7 +589,6 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -681,7 +719,7 @@ def __call__( "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct - # self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") @@ -733,25 +771,21 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare Image latents - # image_latents = self.prepare_image_latents( - # image, - # batch_size, - # num_images_per_prompt, - # prompt_embeds.dtype, - # device, - # do_classifier_free_guidance, - # generator, - # ) - image_latents = self.vae.encode(image).latent_dist.mode() - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([uncond_image_latents, image_latents], dim=0) + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor - # 6. Prepare latent variables + # 7. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -793,13 +827,14 @@ def __call__( ) if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0) - # add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[0].unsqueeze(0)), dim=0) - # add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[0].unsqueeze(0)), dim=0) - # prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[0].unsqueeze(0)), dim=0) + # Make dimension 3 + add_text_embeds = torch.concat((add_text_embeds, add_text_embeds.clone()[-1].unsqueeze(0)), dim=0) + add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[-1].unsqueeze(0)), dim=0) + prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[-1].unsqueeze(0)), dim=0) prompt_embeds = prompt_embeds.to(device).to(torch.float32) add_text_embeds = add_text_embeds.to(device).to(torch.float32) @@ -813,8 +848,7 @@ def __call__( # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. - # latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -822,8 +856,6 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # tar_idx = 2; added_cond_kwargs_tmp = {"text_embeds": add_text_embeds[tar_idx].unsqueeze(0), "time_ids": add_time_ids[tar_idx].unsqueeze(0)} - # self.unet(scaled_latent_model_input[tar_idx].unsqueeze(0), t, encoder_hidden_states=prompt_embeds[tar_idx].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs_tmp, return_dict=False)[0] noise_pred = self.unet( scaled_latent_model_input, t, @@ -844,8 +876,12 @@ def __call__( # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf From f06ed07d0f9752f1cdeda939b76cdd3e3fc136aa Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Mon, 17 Jul 2023 23:02:34 +1000 Subject: [PATCH 37/67] Clean up IP2P SDXL code --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 002db6b4f8ab..53b9e4616aef 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -548,7 +548,6 @@ def prepare_image_latents( return image_latents - # Copied from diffusers.src.diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype ): From ab7004e268a29f77fd9a86837bb47adc1e644822 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Tue, 18 Jul 2023 21:26:22 +1000 Subject: [PATCH 38/67] [IP2P and SDXL] clean up code --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 53b9e4616aef..4e197e9cd37d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -206,7 +206,7 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling From aaaef7b991394727152d5e4c0b4a807ba7fe9223 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Tue, 18 Jul 2023 21:33:17 +1000 Subject: [PATCH 39/67] [IP2P and SDXL] clean up code --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 4e197e9cd37d..3dc4784ce829 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -206,7 +206,6 @@ def enable_model_cpu_offload(self, gpu_id=0): self.final_offload_hook = hook @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. 
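For reference, the three-way classifier-free guidance that patch 36 switches to (chunking the batched prediction into text-, image-, and un-conditioned branches) can be exercised in isolation. The scale values below are only the defaults used elsewhere in this series, not something this hunk prescribes:

```python
import torch

def instruct_pix2pix_guidance(noise_pred, guidance_scale=7.0, image_guidance_scale=1.5):
    # The batch stacks three branches: full conditioning (text + image),
    # image-only conditioning, and no conditioning at all.
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )

# Toy shapes only; in the pipeline each chunk is one copy of the latents.
dummy_pred = torch.randn(3, 4, 64, 64)
guided = instruct_pix2pix_guidance(dummy_pred)
assert guided.shape == (1, 4, 64, 64)
```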
After calling @@ -941,4 +940,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From 698649ddb7618319af2e129762f73e4970a92a93 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Tue, 18 Jul 2023 21:34:59 +1000 Subject: [PATCH 40/67] [IP2P and SDXL] clean up code --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 3dc4784ce829..761fdbd2c5f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -940,4 +940,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From a4772e7a7205b36668443281059a959e70b9795a Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Wed, 19 Jul 2023 21:27:19 +1000 Subject: [PATCH 41/67] [IP2P SDXL] Address code reviews --- .../train_instruct_pix2pix_xl.py | 289 ++++++++++++------ 1 file changed, 188 insertions(+), 101 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py index 3a9ab48198da..ced0d44f5839 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py @@ -1,4 +1,18 @@ -"""Script to train Stable Diffusion XL for InstructPix2Pix.""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse import logging @@ -52,12 +66,25 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] -def is_url(string): - try: - result = urlparse(string) - return all([result.scheme, result.netloc]) - except ValueError: - return False +# Load image from URL or path +def create_image_loader(): + def is_url(string): + try: + result = urlparse(string) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + def load_image(image_url_or_path): + if is_url(image_url_or_path): + return download_image(image_url_or_path) + else: + return Image.open(image_url_or_path).convert("RGB") + + return load_image + + +load_image = create_image_loader() def import_model_class_from_model_name_or_path( @@ -81,7 +108,7 @@ def import_model_class_from_model_name_or_path( def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") + parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -192,9 +219,21 @@ def parse_args(): default=256, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" + " resolution." ), ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) parser.add_argument( "--center_crop", default=False, @@ -449,27 +488,6 @@ def main(): repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id - # Load scheduler, tokenizer and models. - tokenizer_1 = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False - ) - tokenizer_2 = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False - ) - text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - text_encoder_cls_2 = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" - ) - - # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder_1 = text_encoder_cls_1.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - text_encoder_2 = text_encoder_cls_2.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision @@ -495,8 +513,6 @@ def main(): # Freeze vae and text_encoder vae.requires_grad_(False) - text_encoder_1.requires_grad_(False) - text_encoder_2.requires_grad_(False) # Create EMA for the unet. 
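The `create_image_loader` closure above dispatches between a remote URL and a local path for the validation image. A flattened, self-contained version of the same idea is sketched below; the body of the download branch is an assumption, since the training script's own `download_image` helper is not shown in this hunk:

```python
from io import BytesIO
from urllib.parse import urlparse

import requests
from PIL import Image


def load_image(image_url_or_path: str) -> Image.Image:
    parsed = urlparse(image_url_or_path)
    if parsed.scheme and parsed.netloc:  # same URL test as is_url() above
        response = requests.get(image_url_or_path, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image_url_or_path).convert("RGB")
```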
if args.use_ema: @@ -638,6 +654,17 @@ def load_model_hook(models, input_dir): f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" ) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) + + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) + # Preprocessing the datasets. # We need to tokenize input captions and transform the images. def tokenize_captions(captions, a_tokenizer): @@ -672,6 +699,123 @@ def preprocess_images(examples): images = 2 * (images / 255) - 1 return train_transforms(images) + # Load scheduler, tokenizer and models. + tokenizer_1 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + text_encoder_cls_2 = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_1 = text_encoder_cls_1.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_2 = text_encoder_cls_2.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + text_encoder_1.requires_grad_(False) + text_encoder_2.requires_grad_(False) + + # We ALWAYS pre-compute the additional condition embeddings needed for SDXL + # UNet as the model is already big and it uses two text encoders. 
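Because both text encoders are frozen and only used for inference, the script pre-computes their outputs once. A shape sketch of what `encode_prompt` (defined just below) produces, assuming the standard SDXL text-encoder sizes (CLIP ViT-L hidden size 768, OpenCLIP bigG hidden size 1280):

```python
import torch

batch_size, seq_len = 2, 77
hidden_1 = torch.randn(batch_size, seq_len, 768)   # text_encoder_1, penultimate hidden state
hidden_2 = torch.randn(batch_size, seq_len, 1280)  # text_encoder_2, penultimate hidden state
pooled_2 = torch.randn(batch_size, 1280)           # pooled output of the final text encoder

prompt_embeds = torch.concat([hidden_1, hidden_2], dim=-1)
assert prompt_embeds.shape == (batch_size, seq_len, 2048)  # cross-attention conditioning
# pooled_2 becomes the "text_embeds" entry of added_cond_kwargs, next to the time ids.
```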
+ text_encoder_1.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + tokenizers = [tokenizer_1, tokenizer_2] + text_encoders = [text_encoder_1, text_encoder_2] + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt + def encode_prompt(text_encoders, tokenizers, prompt): + prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt + def encode_prompts(text_encoders, tokenizers, prompts): + prompt_embeds_all = [] + pooled_prompt_embeds_all = [] + + for prompt in prompts: + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds_all.append(prompt_embeds) + pooled_prompt_embeds_all.append(pooled_prompt_embeds) + + return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all) + + # Adapted from examples.dreambooth.train_dreambooth_lora_sdxl + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. 
+ def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts) + add_text_embeds_all = pooled_prompt_embeds_all + + prompt_embeds_all = prompt_embeds_all.to(accelerator.device) + add_text_embeds_all = add_text_embeds_all.to(accelerator.device) + return prompt_embeds_all, add_text_embeds_all + + # Get null conditioning + def compute_null_conditioning(): + null_conditioning_list = [] + for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders): + null_conditioning_list.append( + a_text_encoder( + tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), + output_hidden_states=True, + ).hidden_states[-2] + ) + return torch.concat(null_conditioning_list, dim=-1) + + null_conditioning = compute_null_conditioning() + + def compute_time_ids(): + crops_coords_top_left = (0, 0) + original_size = target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype) + return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1) + + add_time_ids = compute_time_ids() + def preprocess_train(examples): # Preprocess images. preprocessed_images = preprocess_images(examples) @@ -688,8 +832,9 @@ def preprocess_train(examples): # Preprocess the captions. captions = list(examples[edit_prompt_column]) - examples["input_ids"] = tokenize_captions(captions, tokenizer_1) - examples["input_ids_2"] = tokenize_captions(captions, tokenizer_2) + prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers) + examples["prompt_embeds"] = prompt_embeds_all + examples["add_text_embeds"] = add_text_embeds_all return examples with accelerator.main_process_first(): @@ -703,13 +848,13 @@ def collate_fn(examples): original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.stack([example["input_ids"] for example in examples]) - input_ids_2 = torch.stack([example["input_ids_2"] for example in examples]) + prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0) + add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0) return { "original_pixel_values": original_pixel_values, "edited_pixel_values": edited_pixel_values, - "input_ids": input_ids, - "input_ids_2": input_ids_2, + "prompt_embeds": prompt_embeds, + "add_text_embeds": add_text_embeds, } # DataLoaders creation: @@ -743,21 +888,7 @@ def collate_fn(examples): if args.use_ema: ema_unet.to(accelerator.device) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. 
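`compute_time_ids` above packs SDXL's size and crop micro-conditioning into one vector per sample. A concrete sketch of the values it produces for the defaults used in this script (resolution 256, crop origin (0, 0)); the batch size of 4 is only an example:

```python
import torch

resolution = 256                      # matches the script's default --resolution
original_size = (resolution, resolution)
crops_coords_top_left = (0, 0)
target_size = (resolution, resolution)

add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
# -> tensor([[256, 256, 0, 0, 256, 256]])

# One row per example in the batch, as compute_time_ids() does with train_batch_size:
batch_time_ids = add_time_ids.repeat(4, 1)
assert batch_time_ids.shape == (4, 6)
```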
- weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) - - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning) - - # Move text_encode and vae to gpu and cast to weight_dtype - text_encoder_1.to(accelerator.device, dtype=weight_dtype) - text_encoder_2.to(accelerator.device, dtype=weight_dtype) - text_encoders = [text_encoder_1, text_encoder_2] + # Move vae to gpu and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -842,24 +973,9 @@ def collate_fn(examples): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - ### Begin encoder prompt - prompt_embeds_list = [] - for input_ids, text_encoder in zip((batch["input_ids"], batch["input_ids_2"]), text_encoders): - prompt_embeds = text_encoder(input_ids, output_hidden_states=True) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - # prompt_embeds = prompt_embeds.repeat(1, 1, 1) - prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) - - prompt_embeds_list.append(prompt_embeds) - - encoder_hidden_states = torch.concat(prompt_embeds_list, dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed, -1) - ### End encoder prompt + # SDXL additional inputs + encoder_hidden_states = batch["prompt_embeds"] + add_text_embeds = batch["add_text_embeds"] # Get the additional image embedding for conditioning. # Instead of getting a diagonal Gaussian here, we simply take the mode. @@ -873,19 +989,6 @@ def collate_fn(examples): prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) # Final text conditioning. - ### Begin: Get null conditioning - null_conditioning_list = [] - for a_tokenizer, a_text_encoder in zip( - (tokenizer_1, tokenizer_2), (text_encoder_1, text_encoder_2) - ): - null_conditioning_list.append( - a_text_encoder( - tokenize_captions([""], a_tokenizer=a_tokenizer).to(accelerator.device), - output_hidden_states=True, - ).hidden_states[-2] - ) - ### End: Get null conditioning - null_conditioning = torch.concat(null_conditioning_list, dim=-1) encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. 
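The `prompt_mask` above implements the InstructPix2Pix conditioning dropout: depending on where each example's random draw falls relative to `conditioning_dropout_prob`, the text conditioning, the image conditioning, or both are replaced with null values. The image-mask half of the scheme is not visible in this hunk; the sketch below reconstructs the thresholds from the original (non-XL) training script, so treat the exact ranges as an assumption:

```python
import torch

p = 0.05                   # conditioning_dropout_prob
random_p = torch.rand(8)   # one draw per example in the batch

drop_text = random_p < 2 * p                       # swap prompt embeds for null conditioning
drop_image = (random_p >= p) & (random_p < 3 * p)  # zero out the original-image latents

# Resulting regimes per example:
#   [0, p)   -> text dropped only
#   [p, 2p)  -> text and image dropped
#   [2p, 3p) -> image dropped only
#   [3p, 1)  -> full conditioning kept
```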
@@ -909,19 +1012,6 @@ def collate_fn(examples): else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - ### Begin SDXL - add_text_embeds = pooled_prompt_embeds - - crops_coords_top_left = (0, 0) - target_size = (args.resolution, args.resolution) - original_size = original_image_embeds.shape[-2:] - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids], dtype=encoder_hidden_states.dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=encoder_hidden_states.dtype) - add_time_ids = add_time_ids.to(encoder_hidden_states.device).repeat(args.train_batch_size, 1) - ### End SDXL - # Predict the noise residual and compute loss added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -1015,10 +1105,7 @@ def collate_fn(examples): if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) - if is_url(args.val_image_url_or_path): - original_image = download_image(args.val_image_url_or_path) - else: - original_image = Image.open(args.val_image_url_or_path).convert("RGB") + original_image = load_image(args.val_image_url_or_path) with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" ): From 79c8f6150e39c3e6157074400062018ad8541070 Mon Sep 17 00:00:00 2001 From: Harutatsu Akiyama Date: Fri, 21 Jul 2023 16:07:27 +1000 Subject: [PATCH 42/67] [IP2P SDXL] Address code reviews, add docs, tests --- docs/source/en/training/instructpix2pix.mdx | 45 ++++++- .../train_instruct_pix2pix_xl.py | 33 ++--- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 45 ++++--- ...stable_diffusion_xl_instruction_pix2pix.py | 125 ++++++++++++++++++ 4 files changed, 203 insertions(+), 45 deletions(-) create mode 100644 tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py diff --git a/docs/source/en/training/instructpix2pix.mdx b/docs/source/en/training/instructpix2pix.mdx index 03ba8f5635d6..6898163b411f 100644 --- a/docs/source/en/training/instructpix2pix.mdx +++ b/docs/source/en/training/instructpix2pix.mdx @@ -1,4 +1,4 @@ -