diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index df41854a9fe7..de33ba616d0a 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -74,6 +74,8 @@ title: ControlNet - local: training/instructpix2pix title: InstructPix2Pix Training + - local: training/custom_diffusion + title: Custom Diffusion title: Training - sections: - local: using-diffusers/rl
diff --git a/docs/source/en/training/custom_diffusion.mdx b/docs/source/en/training/custom_diffusion.mdx
new file mode 100644
index 000000000000..1e1958e1c946
--- /dev/null
+++ b/docs/source/en/training/custom_diffusion.mdx
@@ -0,0 +1,287 @@
+
+
+# Custom Diffusion training example
+
+[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
+The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run:
+
+```bash
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or, for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g. a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+### Cat example 😺
+
+Now let's get our dataset. Download the dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
+
+We also collect 200 real images using `clip-retrieval`, which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target images. The regularization is enabled by the `with_prior_preservation` and `real_prior` flags together with `prior_loss_weight=1.`.
+The `class_prompt` should be the category name of the target images. The collected real images have text captions similar to the `class_prompt` and are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization instead. To collect the real images, run the following command before training.
+ +```bash +pip install clip-retrieval +python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200 +``` + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" +export INSTANCE_DIR="./data/cat" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_cat/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="cat" --num_class_images=200 \ + --instance_prompt="photo of a cat" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --scale_lr --hflip \ + --modifier_token "" +``` + +**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.** + +To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps: + +* Install `wandb`: `pip install wandb`. +* Authorize: `wandb login`. +* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments: + * `num_validation_images` + * `validation_steps` + +Here is an example command: + +```bash +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_cat/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="cat" --num_class_images=200 \ + --instance_prompt="photo of a cat" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --scale_lr --hflip \ + --modifier_token "" \ + --validation_prompt=" cat sitting in a bucket" \ + --report_to="wandb" +``` + +Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details. + +If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat). + +### Training on multiple concepts 🐱🪵 + +Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py). + +To collect the real images run this command for each concept in the json file. + +```bash +pip install clip-retrieval +python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200 +``` + +And then we're ready to start training! 
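+
+For reference, the snippet below sketches one way to generate a minimal `concept_list.json` for two concepts. The `<new1>`/`<new2>` modifier tokens, prompts, and directory paths are illustrative placeholders rather than values shipped with this example — adapt them to your own data.
+
+```python
+# Illustrative sketch: write a concept_list.json for two concepts.
+# The modifier tokens, prompts, and paths below are placeholders — replace them with your own.
+import json
+
+concepts_list = [
+    {
+        "instance_prompt": "photo of a <new1> cat",
+        "class_prompt": "cat",
+        "instance_data_dir": "./data/cat",
+        "class_data_dir": "./real_reg/samples_cat",
+    },
+    {
+        "instance_prompt": "photo of a <new2> wooden pot",
+        "class_prompt": "wooden pot",
+        "instance_data_dir": "./data/wooden_pot",
+        "class_data_dir": "./real_reg/samples_wooden_pot",
+    },
+]
+
+with open("concept_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```
+
+Each entry mirrors the `--instance_prompt`, `--class_prompt`, `--instance_data_dir`, and `--class_data_dir` arguments used for single-concept training.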
+ +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --output_dir=$OUTPUT_DIR \ + --concepts_list=./concept_list.json \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --num_class_images=200 \ + --scale_lr --hflip \ + --modifier_token "+" +``` + +Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details. + +### Training on human faces + +For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images. + +To collect the real images use this command first before training. + +```bash +pip install clip-retrieval +python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200 +``` + +Then start training! + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" +export INSTANCE_DIR="path-to-images" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_person/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="person" --num_class_images=200 \ + --instance_prompt="photo of a person" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=5e-6 \ + --lr_warmup_steps=0 \ + --max_train_steps=1000 \ + --scale_lr --hflip --noaug \ + --freeze_model crossattn \ + --modifier_token "" \ + --enable_xformers_memory_efficient_attention +``` + +## Inference + +Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \ in above example) in your prompt. 
+ +```python +import torch +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda") +pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") +pipe.load_textual_inversion("path-to-save-model", weight_name=".bin") + +image = pipe( + " cat sitting in a bucket", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("cat.png") +``` + +It's possible to directly load these parameters from a Hub repository: + +```python +import torch +from huggingface_hub.repocard import RepoCard +from diffusers import DiffusionPipeline + +model_id = "sayakpaul/custom-diffusion-cat" +card = RepoCard.load(model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda") +pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") + +image = pipe( + " cat sitting in a bucket", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("cat.png") +``` + +Here is an example of performing inference with multiple concepts: + +```python +import torch +from huggingface_hub.repocard import RepoCard +from diffusers import DiffusionPipeline + +model_id = "sayakpaul/custom-diffusion-cat-wooden-pot" +card = RepoCard.load(model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda") +pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") + +image = pipe( + "the cat sculpture in the style of a wooden pot", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("multi-subject.png") +``` + +Here, `cat` and `wooden pot` refer to the multiple concepts. + +### Inference from a training checkpoint + +You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument. + +TODO. + +## Set grads to none +To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument. + +More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html + +## Experimental results +You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. diff --git a/docs/source/en/training/overview.mdx b/docs/source/en/training/overview.mdx index 5ad3a1f06cc1..c5cea3bb0a96 100644 --- a/docs/source/en/training/overview.mdx +++ b/docs/source/en/training/overview.mdx @@ -39,6 +39,8 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie - [Dreambooth](./dreambooth) - [LoRA Support](./lora) - [ControlNet](./controlnet) +- [InstructPix2Pix](./instructpix2pix) +- [Custom Diffusion](./custom_diffusion) If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive. 
@@ -50,6 +52,8 @@ If possible, please [install xFormers](../optimization/xformers) for memory effi | [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | [**Training with LoRA**](./lora) | ✅ | - | - | | [**ControlNet**](./controlnet) | ✅ | ✅ | - | +| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - | +| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - | ## Community
diff --git a/examples/custom_diffusion/README.md b/examples/custom_diffusion/README.md
new file mode 100644
index 000000000000..ecd972737bc3
--- /dev/null
+++ b/examples/custom_diffusion/README.md
@@ -0,0 +1,280 @@
+# Custom Diffusion training example
+
+[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
+The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run:
+
+```bash
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or, for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g. a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+### Cat example 😺
+
+Now let's get our dataset. Download the dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
+
+We also collect 200 real images using `clip-retrieval`, which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target images. The regularization is enabled by the `with_prior_preservation` and `real_prior` flags together with `prior_loss_weight=1.`.
+The `class_prompt` should be the category name of the target images. The collected real images have text captions similar to the `class_prompt` and are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization instead. To collect the real images, run the following command before training.
+ +```bash +pip install clip-retrieval +python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200 +``` + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" +export INSTANCE_DIR="./data/cat" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_cat/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="cat" --num_class_images=200 \ + --instance_prompt="photo of a cat" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --scale_lr --hflip \ + --modifier_token "" +``` + +**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.** + +To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps: + +* Install `wandb`: `pip install wandb`. +* Authorize: `wandb login`. +* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments: + * `num_validation_images` + * `validation_steps` + +Here is an example command: + +```bash +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_cat/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="cat" --num_class_images=200 \ + --instance_prompt="photo of a cat" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=250 \ + --scale_lr --hflip \ + --modifier_token "" \ + --validation_prompt=" cat sitting in a bucket" \ + --report_to="wandb" +``` + +Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details. + +If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat). + +### Training on multiple concepts 🐱🪵 + +Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py). + +To collect the real images run this command for each concept in the json file. + +```bash +pip install clip-retrieval +python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200 +``` + +And then we're ready to start training! 
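+
+For reference, the snippet below sketches one way to generate a minimal `concept_list.json` for two concepts. The `<new1>`/`<new2>` modifier tokens, prompts, and directory paths are illustrative placeholders rather than values shipped with this example — adapt them to your own data.
+
+```python
+# Illustrative sketch: write a concept_list.json for two concepts.
+# The modifier tokens, prompts, and paths below are placeholders — replace them with your own.
+import json
+
+concepts_list = [
+    {
+        "instance_prompt": "photo of a <new1> cat",
+        "class_prompt": "cat",
+        "instance_data_dir": "./data/cat",
+        "class_data_dir": "./real_reg/samples_cat",
+    },
+    {
+        "instance_prompt": "photo of a <new2> wooden pot",
+        "class_prompt": "wooden pot",
+        "instance_data_dir": "./data/wooden_pot",
+        "class_data_dir": "./real_reg/samples_wooden_pot",
+    },
+]
+
+with open("concept_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```
+
+Each entry mirrors the `--instance_prompt`, `--class_prompt`, `--instance_data_dir`, and `--class_data_dir` arguments used for single-concept training.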
+ +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --output_dir=$OUTPUT_DIR \ + --concepts_list=./concept_list.json \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=1e-5 \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --num_class_images=200 \ + --scale_lr --hflip \ + --modifier_token "+" +``` + +Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details. + +### Training on human faces + +For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images. + +To collect the real images use this command first before training. + +```bash +pip install clip-retrieval +python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200 +``` + +Then start training! + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export OUTPUT_DIR="path-to-save-model" +export INSTANCE_DIR="path-to-images" + +accelerate launch train_custom_diffusion.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=./real_reg/samples_person/ \ + --with_prior_preservation --real_prior --prior_loss_weight=1.0 \ + --class_prompt="person" --num_class_images=200 \ + --instance_prompt="photo of a person" \ + --resolution=512 \ + --train_batch_size=2 \ + --learning_rate=5e-6 \ + --lr_warmup_steps=0 \ + --max_train_steps=1000 \ + --scale_lr --hflip --noaug \ + --freeze_model crossattn \ + --modifier_token "" \ + --enable_xformers_memory_efficient_attention +``` + +## Inference + +Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \ in above example) in your prompt. 
+ +```python +import torch +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 +).to("cuda") +pipe.unet.load_attn_procs( + "path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin" +) +pipe.load_textual_inversion("path-to-save-model", weight_name=".bin") + +image = pipe( + " cat sitting in a bucket", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("cat.png") +``` + +It's possible to directly load these parameters from a Hub repository: + +```python +import torch +from huggingface_hub.repocard import RepoCard +from diffusers import DiffusionPipeline + +model_id = "sayakpaul/custom-diffusion-cat" +card = RepoCard.load(model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to( +"cuda") +pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") + +image = pipe( + " cat sitting in a bucket", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("cat.png") +``` + +Here is an example of performing inference with multiple concepts: + +```python +import torch +from huggingface_hub.repocard import RepoCard +from diffusers import DiffusionPipeline + +model_id = "sayakpaul/custom-diffusion-cat-wooden-pot" +card = RepoCard.load(model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to( +"cuda") +pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") +pipe.load_textual_inversion(model_id, weight_name=".bin") + +image = pipe( + "the cat sculpture in the style of a wooden pot", + num_inference_steps=100, + guidance_scale=6.0, + eta=1.0, +).images[0] +image.save("multi-subject.png") +``` + +Here, `cat` and `wooden pot` refer to the multiple concepts. + +### Inference from a training checkpoint + +You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument. + +TODO. + +## Set grads to none +To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument. + +More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html + +## Experimental results +You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. \ No newline at end of file diff --git a/examples/custom_diffusion/requirements.txt b/examples/custom_diffusion/requirements.txt new file mode 100644 index 000000000000..7d93f3d03bd8 --- /dev/null +++ b/examples/custom_diffusion/requirements.txt @@ -0,0 +1,6 @@ +accelerate +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py new file mode 100644 index 000000000000..7b7635c1887d --- /dev/null +++ b/examples/custom_diffusion/retrieve.py @@ -0,0 +1,87 @@ +# Copyright 2023 Custom Diffusion authors. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from io import BytesIO +from pathlib import Path + +import requests +from clip_retrieval.clip_client import ClipClient +from PIL import Image +from tqdm import tqdm + + +def retrieve(class_prompt, class_data_dir, num_class_images): + factor = 1.5 + num_images = int(factor * num_class_images) + client = ClipClient( + url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1 + ) + + os.makedirs(f"{class_data_dir}/images", exist_ok=True) + if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images: + return + + while True: + class_images = client.query(text=class_prompt) + if len(class_images) >= factor * num_class_images or num_images > 1e4: + break + else: + num_images = int(factor * num_images) + client = ClipClient( + url="https://knn.laion.ai/knn-service", + indice_name="laion_400m", + num_images=num_images, + aesthetic_weight=0.1, + ) + + count = 0 + total = 0 + pbar = tqdm(desc="downloading real regularization images", total=num_class_images) + + with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open( + f"{class_data_dir}/images.txt", "w" + ) as f3: + while total < num_class_images: + images = class_images[count] + count += 1 + try: + img = requests.get(images["url"]) + if img.status_code == 200: + _ = Image.open(BytesIO(img.content)) + with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f: + f.write(img.content) + f1.write(images["caption"] + "\n") + f2.write(images["url"] + "\n") + f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n") + total += 1 + pbar.update(1) + else: + continue + except Exception: + continue + return + + +def parse_args(): + parser = argparse.ArgumentParser("", add_help=False) + parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str) + parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str) + parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + retrieve(args.class_prompt, args.class_data_dir, args.num_class_images) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py new file mode 100644 index 000000000000..49b05e6b5db3 --- /dev/null +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -0,0 +1,1289 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 Custom Diffusion authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import hashlib +import itertools +import json +import logging +import math +import os +import random +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import HfApi, create_repo +from packaging import version +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.15.0.dev0") + +logger = get_logger(__name__) + + +def freeze_params(params): + for param in params: + param.requires_grad = False + + +def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- custom-diffusion +inference: true +--- + """ + model_card = f""" +# Custom Diffusion - {repo_id} + +These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n +{img_str} + +\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion). 
+""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def collate_fn(examples, with_prior_preservation): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + mask = [example["mask"] for example in examples] + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + mask += [example["class_mask"] for example in examples] + + input_ids = torch.cat(input_ids, dim=0) + pixel_values = torch.stack(pixel_values) + mask = torch.stack(mask) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + mask = mask.to(memory_format=torch.contiguous_format).float() + + batch = {"input_ids": input_ids, "pixel_values": pixel_values, "mask": mask.unsqueeze(1)} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +class CustomDiffusionDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. 
+ """ + + def __init__( + self, + concepts_list, + tokenizer, + size=512, + mask_size=64, + center_crop=False, + with_prior_preservation=False, + num_class_images=200, + hflip=False, + aug=True, + ): + self.size = size + self.mask_size = mask_size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.interpolation = Image.BILINEAR + self.aug = aug + + self.instance_images_path = [] + self.class_images_path = [] + self.with_prior_preservation = with_prior_preservation + for concept in concepts_list: + inst_img_path = [ + (x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file() + ] + self.instance_images_path.extend(inst_img_path) + + if with_prior_preservation: + class_data_root = Path(concept["class_data_dir"]) + if os.path.isdir(class_data_root): + class_images_path = list(class_data_root.iterdir()) + class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))] + else: + with open(class_data_root, "r") as f: + class_images_path = f.read().splitlines() + with open(concept["class_prompt"], "r") as f: + class_prompt = f.read().splitlines() + + class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)] + self.class_images_path.extend(class_img_path[:num_class_images]) + + random.shuffle(self.instance_images_path) + self.num_instance_images = len(self.instance_images_path) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.flip = transforms.RandomHorizontalFlip(0.5 * hflip) + + self.image_transforms = transforms.Compose( + [ + self.flip, + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def preprocess(self, image, scale, resample): + outer, inner = self.size, scale + factor = self.size // self.mask_size + if scale > self.size: + outer, inner = scale, self.size + top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1) + image = image.resize((scale, scale), resample=resample) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32) + mask = np.zeros((self.size // factor, self.size // factor)) + if scale > self.size: + instance_image = image[top : top + inner, left : left + inner, :] + mask = np.ones((self.size // factor, self.size // factor)) + else: + instance_image[top : top + inner, left : left + inner, :] = image + mask[ + top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1 + ] = 1.0 + return instance_image, mask + + def __getitem__(self, index): + example = {} + instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images] + instance_image = Image.open(instance_image) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + instance_image = self.flip(instance_image) + + # apply resize augmentation and create a valid image region mask + random_scale = self.size + if self.aug: + random_scale = ( + np.random.randint(self.size // 3, self.size + 1) + if np.random.uniform() < 0.66 + else np.random.randint(int(1.2 * self.size), int(1.4 * self.size)) + ) + instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation) + + 
if random_scale < 0.6 * self.size: + instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt + elif random_scale > self.size: + instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt + + example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1) + example["mask"] = torch.from_numpy(mask) + example["instance_prompt_ids"] = self.tokenizer( + instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.with_prior_preservation: + class_image, class_prompt = self.class_images_path[index % self.num_class_images] + class_image = Image.open(class_image) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_mask"] = torch.ones_like(example["mask"]) + example["class_prompt_ids"] = self.tokenizer( + class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir): + """Saves the new token embeddings from the text encoder.""" + logger.info("Saving embeddings") + learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight + for x, y in zip(modifier_token_id, args.modifier_token): + learned_embeds_dict = {} + learned_embeds_dict[y] = learned_embeds[x] + torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Custom Diffusion training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." 
+ ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--real_prior", + default=False, + action="store_true", + help="real images as prior.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=200, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="custom-diffusion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=250, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=2, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--freeze_model", + type=str, + default="crossattn_kv", + choices=["crossattn_kv", "crossattn"], + help="crossattn to enable fine-tuning of all params in the cross attention", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument( + "--concepts_list", + type=str, + default=None, + help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--modifier_token", + type=str, + default=None, + help="A token to use as a modifier for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word." + ) + parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.") + parser.add_argument( + "--noaug", + action="store_true", + help="Dont apply augmentation during data augmentation when this flag is enabled.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.concepts_list is None: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. 
For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("custom-diffusion", config=vars(args)) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + if args.concepts_list is None: + args.concepts_list = [ + { + "instance_prompt": args.instance_prompt, + "class_prompt": args.class_prompt, + "instance_data_dir": args.instance_data_dir, + "class_data_dir": args.class_data_dir, + } + ] + else: + with open(args.concepts_list, "r") as f: + args.concepts_list = json.load(f) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + for i, concept in enumerate(args.concepts_list): + class_images_dir = Path(concept["class_data_dir"]) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True, exist_ok=True) + if args.real_prior: + assert ( + class_images_dir / "images" + ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" + assert ( + len(list((class_images_dir / "images").iterdir())) == args.num_class_images + ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" + assert ( + class_images_dir / "caption.txt" + ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" + assert ( + class_images_dir / "images.txt" + ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" + concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") + concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") + args.concepts_list[i] = concept + accelerator.wait_for_everyone() + else: + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of 
class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = ( + class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + # Adding a modifier token which is optimized #### + # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py + modifier_token_id = [] + initializer_token_id = [] + if args.modifier_token is not None: + args.modifier_token = args.modifier_token.split("+") + args.initializer_token = args.initializer_token.split("+") + if len(args.modifier_token) > len(args.initializer_token): + raise ValueError("You must specify + separated initializer token for each modifier token.") + for modifier_token, initializer_token in zip( + args.modifier_token, args.initializer_token[: len(args.modifier_token)] + ): + # Add the placeholder token in tokenizer + num_added_tokens = tokenizer.add_tokens(modifier_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {modifier_token}. Please pass a different" + " `modifier_token` that is not already in the tokenizer." 
+ ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode([initializer_token], add_special_tokens=False) + print(token_ids) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id.append(token_ids[0]) + modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token)) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + for x, y in zip(modifier_token_id, initializer_token_id): + token_embeds[x] = token_embeds[y] + + # Freeze all parameters except for the token embeddings in text encoder + params_to_freeze = itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + ) + freeze_params(params_to_freeze) + ######################################################## + ######################################################## + + vae.requires_grad_(False) + if args.modifier_token is None: + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + if accelerator.mixed_precision != "fp16" and args.modifier_token is not None: + text_encoder.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + attention_class = CustomDiffusionAttnProcessor + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + attention_class = CustomDiffusionXFormersAttnProcessor + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # now we will add new Custom Diffusion weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. 
+ # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer + train_kv = True + train_q_out = False if args.freeze_model == "crossattn_kv" else True + custom_diffusion_attn_procs = {} + + st = unet.state_dict() + for name, _ in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + layer_name = name.split(".processor")[0] + weights = { + "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], + "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], + } + if train_q_out: + weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] + weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] + weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + if cross_attention_dim is not None: + custom_diffusion_attn_procs[name] = attention_class( + train_kv=train_kv, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ).to(unet.device) + custom_diffusion_attn_procs[name].load_state_dict(weights) + else: + custom_diffusion_attn_procs[name] = attention_class( + train_kv=False, + train_q_out=False, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ) + del st + unet.set_attn_processor(custom_diffusion_attn_procs) + custom_diffusion_layers = AttnProcsLayers(unet.attn_processors) + + accelerator.register_for_checkpointing(custom_diffusion_layers) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.modifier_token is not None: + text_encoder.gradient_checkpointing_enable() + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + if args.with_prior_preservation: + args.learning_rate = args.learning_rate * 2.0 + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + optimizer = optimizer_class( + itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters()) + if args.modifier_token is not None + else custom_diffusion_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = CustomDiffusionDataset( + concepts_list=args.concepts_list, + tokenizer=tokenizer, + with_prior_preservation=args.with_prior_preservation, + size=args.resolution, + mask_size=vae.encode( + torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device) + ) + .latent_dist.sample() + .size()[-1], + center_crop=args.center_crop, + num_class_images=args.num_class_images, + hflip=args.hflip, + aug=not args.noaug, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + if args.modifier_token is not None: + custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.modifier_token is not None: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
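When prior preservation is enabled, the collate function stacks the instance samples first and the class (regularization) samples second along the batch dimension, which is why a single `torch.chunk(..., 2, dim=0)` recovers the two halves. A toy sketch of the split and the combined loss, mirroring the logic below (shapes and values are made up):

```python
# Toy illustration of the prior-preservation loss split.
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0
model_pred = torch.randn(4, 4, 8, 8)  # first 2 samples: instance, last 2: class (prior)
target = torch.randn(4, 4, 8, 8)
mask = torch.ones(4, 1, 8, 8)         # instance masks; all-ones here for simplicity

pred_instance, pred_prior = torch.chunk(model_pred, 2, dim=0)
target_instance, target_prior = torch.chunk(target, 2, dim=0)
mask_instance = torch.chunk(mask, 2, dim=0)[0]

# Masked per-pixel loss on the instance half, plain MSE on the prior half.
instance_loss = F.mse_loss(pred_instance.float(), target_instance.float(), reduction="none")
instance_loss = ((instance_loss * mask_instance).sum([1, 2, 3]) / mask_instance.sum([1, 2, 3])).mean()
prior_loss = F.mse_loss(pred_prior.float(), target_prior.float(), reduction="mean")

loss = instance_loss + prior_loss_weight * prior_loss
```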
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + mask = torch.chunk(batch["mask"], 2, dim=0)[0] + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + mask = batch["mask"] + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() + accelerator.backward(loss) + # Zero out the gradients for all token embeddings except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + if args.modifier_token is not None: + if accelerator.num_processes > 1: + grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad + else: + grads_text_encoder = text_encoder.get_input_embeddings().weight.grad + # Get the index for tokens that we want to zero the grads for + index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0] + for i in range(len(modifier_token_id[1:])): + index_grads_to_zero = index_grads_to_zero & ( + torch.arange(len(tokenizer)) != modifier_token_id[i] + ) + grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[ + index_grads_to_zero, : + ].fill_(0) + + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(text_encoder.parameters(), custom_diffusion_layers.parameters()) + if args.modifier_token is not None + else custom_diffusion_layers.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + revision=args.revision, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the custom diffusion layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unet.to(torch.float32) + unet.save_attn_procs(args.output_dir) + save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir) + + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin") + for token in args.modifier_token: + pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin") + + # run inference + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + ) + api = HfApi(token=args.hub_token) + api.upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/test_examples.py b/examples/test_examples.py index d9a1f86e53aa..a77fa4c7da23 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -221,6 +221,30 @@ def test_dreambooth_checkpointing(self): self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) + 
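The script's final block above reloads the saved attention processors and the learned token embedding for one last round of validation images. The same two calls are what an end user would run for inference; a sketch, assuming the modifier token was `<new1>` and the weights were saved to `path-to-save-model` (all paths, prompts and sampler settings below are example values):

```python
# Inference with a trained Custom Diffusion checkpoint.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Load the fine-tuned cross-attention projections and the learned <new1> embedding.
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")

image = pipe(
    "<new1> cat sitting in a bucket", num_inference_steps=100, guidance_scale=6.0, eta=1.0
).images[0]
image.save("cat.png")
```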
def test_custom_diffusion(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/custom_diffusion/train_custom_diffusion.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt <new1>
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 1.0e-05
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --modifier_token <new1>
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
+
     def test_text_to_image(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e814981a85c9..7f41ddaba498 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -17,7 +17,11 @@
 import torch
 
-from .models.attention_processor import LoRAAttnProcessor
+from .models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnProcessor,
+)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,6 +50,9 @@
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -213,6 +220,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -229,9 +237,38 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
-
+        elif is_custom_diffusion:
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                    attn_processors[key].load_state_dict(value_dict)
         else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+            raise ValueError(
+                f"{model_file} does not

seem to be in the correct format expected by LoRA or Custom Diffusion training." + ) # set correct dtype & device attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} @@ -285,16 +322,31 @@ def save_function(weights, filename): os.makedirs(save_directory, exist_ok=True) - model_to_save = AttnProcsLayers(self.attn_processors) - - # Save the model - state_dict = model_to_save.state_dict() + is_custom_diffusion = any( + isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + for (_, x) in self.attn_processors.items() + ) + if is_custom_diffusion: + model_to_save = AttnProcsLayers( + { + y: x + for (y, x) in self.attn_processors.items() + if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) + } + ) + state_dict = model_to_save.state_dict() + for name, attn in self.attn_processors.items(): + if len(attn.state_dict()) == 0: + state_dict[name] = {} + else: + model_to_save = AttnProcsLayers(self.attn_processors) + state_dict = model_to_save.state_dict() if weight_name is None: if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE else: - weight_name = LORA_WEIGHT_NAME + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f2a5a376bf39..b8787aed91f2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -149,6 +149,9 @@ def set_use_memory_efficient_attention_xformers( is_lora = hasattr(self, "processor") and isinstance( self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) + ) if use_memory_efficient_attention_xformers: if self.added_kv_proj_dim is not None: @@ -192,6 +195,17 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -203,6 +217,16 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) else: processor = AttnProcessor() @@ -459,6 +483,84 @@ def 
__call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class CustomDiffusionAttnProcessor(nn.Module): + def __init__( + self, + train_kv=True, + train_q_out=True, + hidden_size=None, + cross_attention_dim=None, + out_bias=True, + dropout=0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + # `_custom_diffusion` id for easy serialization and loading. + if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states) + value = self.to_v_custom_diffusion(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + class AttnAddedKVProcessor: def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states @@ -699,6 +801,91 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class CustomDiffusionXFormersAttnProcessor(nn.Module): + def __init__( + self, + train_kv=True, + train_q_out=False, + hidden_size=None, + cross_attention_dim=None, + out_bias=True, + dropout=0.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + # `_custom_diffusion` id for easy serialization and loading. 
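Which projections a processor actually owns depends on `train_kv` and `train_q_out`: with `--freeze_model=crossattn_kv` only the key and value projections are trainable, while `crossattn` also trains the query and output projections. A quick way to inspect this, with example sizes for one SD 1.x layer (not part of the patch itself):

```python
# Inspect the parameters a Custom Diffusion processor registers.
# hidden_size/cross_attention_dim below are example values only.
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

proc = CustomDiffusionAttnProcessor(
    train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768
)
for name, param in proc.named_parameters():
    print(name, tuple(param.shape))
# Expected: to_k_custom_diffusion.weight (320, 768) and to_v_custom_diffusion.weight (320, 768).
# Setting train_q_out=True additionally registers to_q_custom_diffusion and to_out_custom_diffusion.0.
```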
+ if self.train_kv: + self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + if self.train_q_out: + self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) + self.to_out_custom_diffusion = nn.ModuleList([]) + self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) + self.to_out_custom_diffusion.append(nn.Dropout(dropout)) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + if self.train_kv: + key = self.to_k_custom_diffusion(encoder_hidden_states) + value = self.to_v_custom_diffusion(encoder_hidden_states) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if crossattn: + detach = torch.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if self.train_q_out: + # linear proj + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + # dropout + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + class SlicedAttnProcessor: def __init__(self, slice_size): self.slice_size = slice_size @@ -834,4 +1021,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnAddedKVProcessor2_0, LoRAAttnProcessor, LoRAXFormersAttnProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, ] diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 15f77fb8c106..2576297762a8 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -22,7 +22,7 @@ from parameterized import parameterized from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor from diffusers.utils import ( floats_tensor, load_hf_numpy, @@ -68,6 +68,55 @@ def create_lora_layers(model, mock_weights: bool = True): return lora_attn_procs +def create_custom_diffusion_layers(model, mock_weights: bool = True): + train_kv = True + train_q_out = True + custom_diffusion_attn_procs = {} + + st 
= model.state_dict() + for name, _ in model.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + layer_name = name.split(".processor")[0] + weights = { + "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], + "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], + } + if train_q_out: + weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] + weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] + weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + if cross_attention_dim is not None: + custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor( + train_kv=train_kv, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ).to(model.device) + custom_diffusion_attn_procs[name].load_state_dict(weights) + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + custom_diffusion_attn_procs[name].to_k_custom_diffusion.weight += 1 + custom_diffusion_attn_procs[name].to_v_custom_diffusion.weight += 1 + else: + custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor( + train_kv=False, + train_q_out=False, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ) + del st + return custom_diffusion_attn_procs + + class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel @@ -569,6 +618,96 @@ def test_lora_xformers_on_off(self): assert (sample - on_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4 + def test_custom_diffusion_processors(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + + # make sure we can set a list of attention processors + model.set_attn_processor(custom_diffusion_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + assert (sample1 - sample2).abs().max() < 1e-4 + + def test_custom_diffusion_save_load(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + model.set_attn_processor(custom_diffusion_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + 
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") + + with torch.no_grad(): + new_sample = new_model(**inputs_dict).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # custom diffusion and no custom diffusion should be the same + assert (sample - old_sample).abs().max() < 1e-4 + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_custom_diffusion_xformers_on_off(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + model.set_attn_processor(custom_diffusion_attn_procs) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + assert (sample - on_sample).abs().max() < 1e-4 + assert (sample - off_sample).abs().max() < 1e-4 + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase):