From 94ea42018024dcc0135487ef3338e4c25404288f Mon Sep 17 00:00:00 2001 From: neuron-party Date: Thu, 16 May 2024 14:22:48 -0700 Subject: [PATCH 1/2] syncing with main --- .github/workflows/build_docker_images.yml | 23 +- .github/workflows/pr_tests.yml | 2 +- .github/workflows/push_tests.yml | 1 + .github/workflows/push_tests_fast.yml | 2 +- .github/workflows/push_tests_mps.yml | 2 +- .github/workflows/run_test_from_pr.yml | 56 + CONTRIBUTING.md | 2 +- docs/source/en/using-diffusers/callback.md | 67 +- examples/community/README.md | 63 + .../pipeline_stable_diffusion_boxdiff.py | 1700 +++++++++++++++++ ...pipeline_stable_diffusion_upscale_ldm3d.py | 2 +- .../train_lcm_distill_lora_sdxl_wds.py | 2 +- examples/dreambooth/train_dreambooth.py | 2 +- .../realfill/requirements.txt | 2 +- examples/text_to_image/requirements.txt | 4 +- .../text_to_image/train_text_to_image_sdxl.py | 25 +- examples/vqgan/README.md | 127 ++ examples/vqgan/discriminator.py | 48 + examples/vqgan/requirements.txt | 8 + examples/vqgan/test_vqgan.py | 395 ++++ examples/vqgan/train_vqgan.py | 1067 +++++++++++ src/diffusers/callbacks.py | 156 ++ src/diffusers/commands/env.py | 110 +- src/diffusers/loaders/lora.py | 2 +- src/diffusers/loaders/textual_inversion.py | 27 +- src/diffusers/models/autoencoders/vae.py | 1 + src/diffusers/models/model_loading_utils.py | 149 ++ src/diffusers/models/modeling_utils.py | 121 +- .../models/unets/unet_2d_condition.py | 2 +- .../models/unets/unet_motion_model.py | 34 + src/diffusers/models/vq_model.py | 15 +- .../pipelines/controlnet/multicontrolnet.py | 10 +- .../controlnet/pipeline_controlnet.py | 18 +- .../controlnet/pipeline_controlnet_img2img.py | 18 +- .../controlnet/pipeline_controlnet_inpaint.py | 18 +- .../pipeline_controlnet_inpaint_sd_xl.py | 42 +- .../controlnet/pipeline_controlnet_sd_xl.py | 109 +- .../pipeline_controlnet_sd_xl_img2img.py | 34 +- .../controlnet_xs/pipeline_controlnet_xs.py | 18 +- .../pipeline_controlnet_xs_sd_xl.py | 28 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- .../pixart_alpha/pipeline_pixart_alpha.py | 2 +- .../pixart_alpha/pipeline_pixart_sigma.py | 2 +- .../pipeline_onnx_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_diffusion.py | 19 +- .../pipeline_stable_diffusion_img2img.py | 18 +- .../pipeline_stable_diffusion_inpaint.py | 18 +- ...eline_stable_diffusion_instruct_pix2pix.py | 18 +- ...ipeline_stable_diffusion_latent_upscale.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_diffusion_diffedit.py | 2 +- .../pipeline_stable_diffusion_xl.py | 18 +- .../pipeline_stable_diffusion_xl_img2img.py | 18 +- .../pipeline_stable_diffusion_xl_inpaint.py | 18 +- .../scheduling_consistency_models.py | 6 +- .../schedulers/scheduling_edm_euler.py | 6 +- .../scheduling_euler_ancestral_discrete.py | 6 +- .../schedulers/scheduling_euler_discrete.py | 6 +- src/diffusers/utils/__init__.py | 5 + src/diffusers/utils/import_utils.py | 66 + src/diffusers/utils/testing_utils.py | 8 + tests/lora/test_lora_layers_sdxl.py | 5 +- tests/models/autoencoders/test_models_vq.py | 16 + .../test_controlnet_inpaint_sdxl.py | 45 +- .../test_controlnet_sdxl_img2img.py | 5 + .../test_stable_diffusion_single_file.py | 1 + 66 files changed, 4510 insertions(+), 318 deletions(-) create mode 100644 .github/workflows/run_test_from_pr.yml create mode 100644 examples/community/pipeline_stable_diffusion_boxdiff.py create mode 100644 examples/vqgan/README.md create mode 100644 examples/vqgan/discriminator.py create mode 100644 examples/vqgan/requirements.txt create mode 100644 examples/vqgan/test_vqgan.py create mode 100644 examples/vqgan/train_vqgan.py create mode 100644 src/diffusers/callbacks.py create mode 100644 src/diffusers/models/model_loading_utils.py diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index 82ef885b240e..f2f7709e86c1 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -90,24 +90,11 @@ jobs: - name: Post to a Slack channel id: slack - uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 + uses: huggingface/hf-workflows/.github/actions/post-slack@main with: # Slack channel id, channel name, or user id to post message. # See also: https://api.slack.com/methods/chat.postMessage#channels - channel-id: ${{ env.CI_SLACK_CHANNEL }} - # For posting a rich message using Block Kit - payload: | - { - "text": "${{ matrix.image-name }} Docker Image build result: ${{ job.status }}\n${{ github.event.head_commit.url }}", - "blocks": [ - { - "type": "section", - "text": { - "type": "mrkdwn", - "text": "${{ matrix.image-name }} Docker Image build result: ${{ job.status }}\n${{ github.event.head_commit.url }}" - } - } - ] - } - env: - SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: "🤗 Results of the ${{ matrix.image-name }} Docker Image build" + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index b1bed6568aa4..d5d1fc719305 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -156,7 +156,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install peft + python -m uv pip install peft timm python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 66ced2182ff3..4dfef70e317a 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -426,6 +426,7 @@ jobs: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install timm python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ - name: Failure short reports diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 7c50da7b5c34..54ff48993768 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -107,7 +107,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install peft + python -m uv pip install peft timm python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 3a14f856346b..60165bc41471 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -23,7 +23,7 @@ concurrency: jobs: run_fast_tests_apple_m1: name: Fast PyTorch MPS tests on MacOS - runs-on: [ self-hosted, apple-m1 ] + runs-on: macos-13-xlarge steps: - name: Checkout diffusers diff --git a/.github/workflows/run_test_from_pr.yml b/.github/workflows/run_test_from_pr.yml new file mode 100644 index 000000000000..27081ac584d6 --- /dev/null +++ b/.github/workflows/run_test_from_pr.yml @@ -0,0 +1,56 @@ +name: Run (SLOW) desired tests on our runner from a PR (applicable to GPUs only at the moment) + +on: + workflow_dispatch: + inputs: + pr_number: + description: 'PR number' + required: true + docker_image: + default: 'diffusers/diffusers-pytorch-cuda' + description: 'Name of the Docker image' + required: true + test_command: + description: 'Test command to run (e.g.: `pytest tests/pipelines/dit/`). Any valid pytest command can be provided.' + required: true + +env: + IS_GITHUB_CI: "1" + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + HF_HOME: /mnt/cache + DIFFUSERS_IS_CI: yes + OMP_NUM_THREADS: 8 + MKL_NUM_THREADS: 8 + RUN_SLOW: yes + +jobs: + run_tests: + name: "Run a test on our runner from a PR" + runs-on: [single-gpu, nvidia-gpu, "t4", ci] + container: + image: ${{ github.event.inputs.docker_image }} + options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ + + steps: + - name: NVIDIA-SMI + run: | + nvidia-smi + + - uses: actions/checkout@v3 + - name: Install `gh` + run: | + : # see https://github.com/cli/cli/blob/trunk/docs/install_linux.md#debian-ubuntu-linux-raspberry-pi-os-apt + (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \ + && mkdir -p -m 755 /etc/apt/keyrings \ + && wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \ + && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \ + && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \ + && apt update \ + && apt install gh -y + + - name: Checkout the PR branch + run: | + gh pr checkout ${{ github.event.inputs.pr_number }} + + - name: Run tests + run: ${{ github.event.inputs.test_command }} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 887e4dd43c45..59d39155952a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference. -Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)): +Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)): 1. Fork the [repository](https://github.com/huggingface/diffusers) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index 3f3e8dae9f2d..7445513dbf4b 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -19,13 +19,74 @@ The denoising loop of a pipeline can be modified with custom defined functions u This guide will demonstrate how callbacks work by a few features you can implement with them. +## Official callbacks + +We provide a list of callbacks you can plug into an existing pipeline and modify the denoising loop. This is the current list of official callbacks: + +- `SDCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SD 1.5 pipelines, including text-to-image, image-to-image, inpaint, and controlnet. +- `SDXLCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SDXL pipelines, including text-to-image, image-to-image, inpaint, and controlnet. +- `IPAdapterScaleCutoffCallback`: Disables the IP Adapter after a certain number of steps for all pipelines supporting IP-Adapter. + +> [!TIP] +> If you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr). + +To set up a callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments + +- `cutoff_step_ratio`: Float number with the ratio of the steps. +- `cutoff_step_index`: Integer number with the exact number of the step. + +```python +import torch + +from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline +from diffusers.callbacks import SDXLCFGCutoffCallback + + +callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4) +# can also be used with cutoff_step_index +# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10) + +pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + variant="fp16", +).to("cuda") +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True) + +prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution" + +generator = torch.Generator(device="cpu").manual_seed(2628670641) + +out = pipeline( + prompt=prompt, + negative_prompt="", + guidance_scale=6.5, + num_inference_steps=25, + generator=generator, + callback_on_step_end=callback, +) + +out.images[0].save("official_callback.png") +``` + +
+
+ generated image of a sports car at the road +
without SDXLCFGCutoffCallback
+
+
+ generated image of a a sports car at the road with cfg callback +
with SDXLCFGCutoffCallback
+
+
+ ## Dynamic classifier-free guidance Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments: -* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`. -* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`. -* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly. +- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`. +- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`. +- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly. Your callback function should look something like this: diff --git a/examples/community/README.md b/examples/community/README.md index 5cebc4f9f049..8afe8d42e3d4 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -68,6 +68,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) | | UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) | | Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | +| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -1676,6 +1677,68 @@ image = pipe(prompt, image=input_image, strength=0.75,).images[0] image.save('tensorrt_img2img_new_zealand_hills.png') ``` +### Stable Diffusion BoxDiff +BoxDiff is a training-free method for controlled generation with bounding box coordinates. It shoud work with any Stable Diffusion model. Below shows an example with `stable-diffusion-2-1-base`. +```py +import torch +from PIL import Image, ImageDraw +from copy import deepcopy + +from examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline + +def draw_box_with_text(img, boxes, names): + colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] + img_new = deepcopy(img) + draw = ImageDraw.Draw(img_new) + + W, H = img.size + for bid, box in enumerate(boxes): + draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4) + draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)]) + return img_new + +pipe = StableDiffusionBoxDiffPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", + torch_dtype=torch.float16, +) +pipe.to("cuda") + +# example 1 +prompt = "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed" +phrases = [ + "aurora", + "reindeer", + "meadow", + "lake", + "mountain" +] +boxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]] + +# example 2 +# prompt = "A rabbit wearing sunglasses looks very proud" +# phrases = ["rabbit", "sunglasses"] +# boxes = [[67,87,366,512], [66,130,364,262]] + +boxes = [[x / 512 for x in box] for box in boxes] + +images = pipe( + prompt, + boxdiff_phrases=phrases, + boxdiff_boxes=boxes, + boxdiff_kwargs={ + "attention_res": 16, + "normalize_eot": True + }, + num_inference_steps=50, + guidance_scale=7.5, + generator=torch.manual_seed(42), + safety_checker=None +).images + +draw_box_with_text(images[0], boxes, phrases).save("output.png") +``` + + ### Stable Diffusion Reference This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py new file mode 100644 index 000000000000..f82533944132 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_boxdiff.py @@ -0,0 +1,1700 @@ +# Copyright 2024 Jingyang Zhang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import inspect +import math +import numbers +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0 +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class GaussianSmoothing(nn.Module): + """ + Copied from official repo: https://github.com/showlab/BoxDiff/blob/master/utils/gaussian_smoothing.py + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) + + +class AttendExciteCrossAttnProcessor: + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1) + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AttentionControl(abc.ABC): + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + # @property + # def num_uncond_att_layers(self): + # return 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + self.forward(attn, is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + +class AttentionStore(AttentionControl): + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32**2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + self.attention_store = self.step_store + if self.save_global_store: + with torch.no_grad(): + if len(self.global_store) == 0: + self.global_store = self.step_store + else: + for key in self.global_store: + for i in range(len(self.global_store[key])): + self.global_store[key][i] += self.step_store[key][i].detach() + self.step_store = self.get_empty_store() + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def get_average_global_attention(self): + average_attention = { + key: [item / self.cur_step for item in self.global_store[key]] for key in self.attention_store + } + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + + def __init__(self, save_global_store=False): + """ + Initialize an empty AttentionStore + :param step_index: used to visualize only a specific step in the diffusion process + """ + super(AttentionStore, self).__init__() + self.save_global_store = save_global_store + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + self.curr_step_index = 0 + self.num_uncond_att_layers = 0 + + +def aggregate_attention( + attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int +) -> torch.Tensor: + """Aggregates the attention across the different layers and heads at the specified resolution.""" + out = [] + attention_maps = attention_store.get_average_attention() + + # for k, v in attention_maps.items(): + # for vv in v: + # print(vv.shape) + # exit() + + num_pixels = res**2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + +def register_attention_control(model, controller): + attn_procs = {} + cross_att_count = 0 + for name in model.unet.attn_processors.keys(): + # cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim + if name.startswith("mid_block"): + # hidden_size = model.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + # block_id = int(name[len("up_blocks.")]) + # hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + # block_id = int(name[len("down_blocks.")]) + # hidden_size = model.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + attn_procs[name] = AttendExciteCrossAttnProcessor(attnstore=controller, place_in_unet=place_in_unet) + model.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionBoxDiffPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with BoxDiff. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return text_inputs, prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + boxdiff_phrases, + boxdiff_boxes, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if boxdiff_phrases is not None or boxdiff_boxes is not None: + if not (boxdiff_phrases is not None and boxdiff_boxes is not None): + raise ValueError("Either both `boxdiff_phrases` and `boxdiff_boxes` must be passed or none of them.") + + if not isinstance(boxdiff_phrases, list) or not isinstance(boxdiff_boxes, list): + raise ValueError("`boxdiff_phrases` and `boxdiff_boxes` must be lists.") + + if len(boxdiff_phrases) != len(boxdiff_boxes): + raise ValueError( + "`boxdiff_phrases` and `boxdiff_boxes` must have the same length," + f" got: `boxdiff_phrases` {len(boxdiff_phrases)} != `boxdiff_boxes`" + f" {len(boxdiff_boxes)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def _compute_max_attention_per_index( + self, + attention_maps: torch.Tensor, + indices_to_alter: List[int], + smooth_attentions: bool = False, + sigma: float = 0.5, + kernel_size: int = 3, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ) -> List[torch.Tensor]: + """Computes the maximum attention value for each of the tokens we wish to alter.""" + last_idx = -1 + if normalize_eot: + prompt = self.prompt + if isinstance(self.prompt, list): + prompt = self.prompt[0] + last_idx = len(self.tokenizer(prompt)["input_ids"]) - 1 + attention_for_text = attention_maps[:, :, 1:last_idx] + attention_for_text *= 100 + attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) + + # Shift indices since we removed the first token "1:last_idx" + indices_to_alter = [index - 1 for index in indices_to_alter] + + # Extract the maximum values + max_indices_list_fg = [] + max_indices_list_bg = [] + dist_x = [] + dist_y = [] + + cnt = 0 + for i in indices_to_alter: + image = attention_for_text[:, :, i] + + # TODO + # box = [max(round(b / (512 / image.shape[0])), 0) for b in bboxes[cnt]] + # x1, y1, x2, y2 = box + H, W = image.shape + x1 = min(max(round(bboxes[cnt][0] * W), 0), W) + y1 = min(max(round(bboxes[cnt][1] * H), 0), H) + x2 = min(max(round(bboxes[cnt][2] * W), 0), W) + y2 = min(max(round(bboxes[cnt][3] * H), 0), H) + box = [x1, y1, x2, y2] + cnt += 1 + + # coordinates to masks + obj_mask = torch.zeros_like(image) + ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device) + obj_mask[y1:y2, x1:x2] = ones_mask + bg_mask = 1 - obj_mask + + if smooth_attentions: + smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(image.device) + input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") + image = smoothing(input).squeeze(0).squeeze(0) + + # Inner-Box constraint + k = (obj_mask.sum() * P).long() + max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean()) + + # Outer-Box constraint + k = (bg_mask.sum() * P).long() + max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean()) + + # Corner Constraint + gt_proj_x = torch.max(obj_mask, dim=0)[0] + gt_proj_y = torch.max(obj_mask, dim=1)[0] + corner_mask_x = torch.zeros_like(gt_proj_x) + corner_mask_y = torch.zeros_like(gt_proj_y) + + # create gt according to the number config.L + N = gt_proj_x.shape[0] + corner_mask_x[max(box[0] - L, 0) : min(box[0] + L + 1, N)] = 1.0 + corner_mask_x[max(box[2] - L, 0) : min(box[2] + L + 1, N)] = 1.0 + corner_mask_y[max(box[1] - L, 0) : min(box[1] + L + 1, N)] = 1.0 + corner_mask_y[max(box[3] - L, 0) : min(box[3] + L + 1, N)] = 1.0 + dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction="none") * corner_mask_x).mean()) + dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction="none") * corner_mask_y).mean()) + + return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y + + def _aggregate_and_get_max_attention_per_token( + self, + attention_store: AttentionStore, + indices_to_alter: List[int], + attention_res: int = 16, + smooth_attentions: bool = False, + sigma: float = 0.5, + kernel_size: int = 3, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ): + """Aggregates the attention for each token and computes the max activation value for each token to alter.""" + attention_maps = aggregate_attention( + attention_store=attention_store, + res=attention_res, + from_where=("up", "down", "mid"), + is_cross=True, + select=0, + ) + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index( + attention_maps=attention_maps, + indices_to_alter=indices_to_alter, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + + @staticmethod + def _compute_loss( + max_attention_per_index_fg: List[torch.Tensor], + max_attention_per_index_bg: List[torch.Tensor], + dist_x: List[torch.Tensor], + dist_y: List[torch.Tensor], + return_losses: bool = False, + ) -> torch.Tensor: + """Computes the attend-and-excite loss using the maximum attention value for each token.""" + losses_fg = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index_fg] + losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] + loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) + if return_losses: + return max(losses_fg), losses_fg + else: + return max(losses_fg), loss + + @staticmethod + def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: + """Update the latent according to the computed loss.""" + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] + latents = latents - step_size * grad_cond + return latents + + def _perform_iterative_refinement_step( + self, + latents: torch.Tensor, + indices_to_alter: List[int], + loss_fg: torch.Tensor, + threshold: float, + text_embeddings: torch.Tensor, + text_input, + attention_store: AttentionStore, + step_size: float, + t: int, + attention_res: int = 16, + smooth_attentions: bool = True, + sigma: float = 0.5, + kernel_size: int = 3, + max_refinement_steps: int = 20, + normalize_eot: bool = False, + bboxes: List[int] = None, + L: int = 1, + P: float = 0.2, + ): + """ + Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent + code according to our loss objective until the given threshold is reached for all tokens. + """ + iteration = 0 + target_loss = max(0, 1.0 - threshold) + + while loss_fg > target_loss: + iteration += 1 + + latents = latents.clone().detach().requires_grad_(True) + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + loss_fg, losses_fg = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True + ) + + if loss_fg != 0: + latents = self._update_latent(latents, loss_fg, step_size) + + # with torch.no_grad(): + # noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + + # try: + # low_token = np.argmax([l.item() if not isinstance(l, int) else l for l in losses_fg]) + # except Exception as e: + # print(e) # catch edge case :) + # low_token = np.argmax(losses_fg) + + # low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) + # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}') + + if iteration >= max_refinement_steps: + # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' + # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}') + break + + # Run one more time but don't compute gradients and update the latents. + # We just need to compute the new loss - the grad update will occur below + latents = latents.clone().detach().requires_grad_(True) + # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + loss_fg, losses_fg = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True + ) + # print(f"\t Finished with loss of: {loss_fg}") + return loss_fg, latents, max_attention_per_index_fg + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + boxdiff_phrases: List[str] = None, + boxdiff_boxes: List[List[float]] = None, # TODO + boxdiff_kwargs: Optional[Dict[str, Any]] = { + "attention_res": 16, + "P": 0.2, + "L": 1, + "max_iter_to_alter": 25, + "loss_thresholds": {0: 0.05, 10: 0.5, 20: 0.8}, + "scale_factor": 20, + "scale_range": (1.0, 0.5), + "smooth_attentions": True, + "sigma": 0.5, + "kernel_size": 3, + "refine": False, + "normalize_eot": True, + }, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + + boxdiff_attention_res (`int`, *optional*, defaults to 16): + The resolution of the attention maps used for computing the BoxDiff loss. + boxdiff_P (`float`, *optional*, defaults to 0.2): + + boxdiff_L (`int`, *optional*, defaults to 1): + The number of pixels around the corner to be selected in BoxDiff loss. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # -1. Register attention control (for BoxDiff) + attention_store = AttentionStore() + register_attention_control(self, attention_store) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + boxdiff_phrases, + boxdiff_boxes, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + self.prompt = prompt + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + text_inputs, prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 6.3 Prepare BoxDiff inputs + # a) Indices to alter + input_ids = self.tokenizer(prompt)["input_ids"] + decoded = [self.tokenizer.decode([t]) for t in input_ids] + indices_to_alter = [] + bboxes = [] + for phrase, box in zip(boxdiff_phrases, boxdiff_boxes): + # it could happen that phrase does not correspond a single token? + if phrase not in decoded: + continue + indices_to_alter.append(decoded.index(phrase)) + bboxes.append(box) + + # b) A bunch of hyperparameters + attention_res = boxdiff_kwargs.get("attention_res", 16) + smooth_attentions = boxdiff_kwargs.get("smooth_attentions", True) + sigma = boxdiff_kwargs.get("sigma", 0.5) + kernel_size = boxdiff_kwargs.get("kernel_size", 3) + L = boxdiff_kwargs.get("L", 1) + P = boxdiff_kwargs.get("P", 0.2) + thresholds = boxdiff_kwargs.get("loss_thresholds", {0: 0.05, 10: 0.5, 20: 0.8}) + max_iter_to_alter = boxdiff_kwargs.get("max_iter_to_alter", len(self.scheduler.timesteps) + 1) + scale_factor = boxdiff_kwargs.get("scale_factor", 20) + refine = boxdiff_kwargs.get("refine", False) + normalize_eot = boxdiff_kwargs.get("normalize_eot", True) + + scale_range = boxdiff_kwargs.get("scale_range", (1.0, 0.5)) + scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # BoxDiff optimization + with torch.enable_grad(): + latents = latents.clone().detach().requires_grad_(True) + + # Forward pass of denoising with text conditioning + noise_pred_text = self.unet( + latents, + t, + encoder_hidden_states=prompt_embeds[1].unsqueeze(0), + cross_attention_kwargs=cross_attention_kwargs, + ).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + ( + max_attention_per_index_fg, + max_attention_per_index_bg, + dist_x, + dist_y, + ) = self._aggregate_and_get_max_attention_per_token( + attention_store=attention_store, + indices_to_alter=indices_to_alter, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + loss_fg, loss = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + ) + + # Refinement from attend-and-excite (not necessary) + if refine and i in thresholds.keys() and loss_fg > 1.0 - thresholds[i]: + del noise_pred_text + torch.cuda.empty_cache() + loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step( + latents=latents, + indices_to_alter=indices_to_alter, + loss_fg=loss_fg, + threshold=thresholds[i], + text_embeddings=prompt_embeds, + text_input=text_inputs, + attention_store=attention_store, + step_size=scale_factor * np.sqrt(scale_range[i]), + t=t, + attention_res=attention_res, + smooth_attentions=smooth_attentions, + sigma=sigma, + kernel_size=kernel_size, + normalize_eot=normalize_eot, + bboxes=bboxes, + L=L, + P=P, + ) + + # Perform gradient update + if i < max_iter_to_alter: + _, loss = self._compute_loss( + max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y + ) + if loss != 0: + latents = self._update_latent( + latents=latents, loss=loss, step_size=scale_factor * np.sqrt(scale_range[i]) + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py index 0622db005d76..a873e7b2956e 100644 --- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py +++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py @@ -460,7 +460,7 @@ def check_inputs( ) # verify batch size of prompt and image are same if image is a list or tensor or numpy array - if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): + if isinstance(image, (list, np.ndarray, torch.Tensor)): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 08d6b23d6deb..ce3e7f624843 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -1358,7 +1358,7 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + if torch.backends.mps.is_available() or "playground" in args.pretrained_teacher_model: autocast_ctx = nullcontext() else: autocast_ctx = torch.autocast(accelerator.device.type) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a18c443e7d4d..103e3b5b10b8 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -759,7 +759,7 @@ def __getitem__(self, index): def model_has_vae(args): - config_file_name = os.path.join("vae", AutoencoderKL.config_name) + config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() if os.path.isdir(args.pretrained_model_name_or_path): config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) return os.path.isfile(config_file_name) diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt index f6abdc6e7e20..c9e4e7e4ae72 100644 --- a/examples/research_projects/realfill/requirements.txt +++ b/examples/research_projects/realfill/requirements.txt @@ -1,6 +1,6 @@ diffusers==0.20.1 accelerate==0.23.0 -transformers==4.36.0 +transformers==4.38.0 peft==0.5.0 torch==2.0.1 torchvision>=0.16 diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt index 0dd164fc2035..c3ffa42f0edc 100644 --- a/examples/text_to_image/requirements.txt +++ b/examples/text_to_image/requirements.txt @@ -1,8 +1,8 @@ accelerate>=0.16.0 torchvision transformers>=4.25.1 -datasets +datasets>=2.19.1 ftfy tensorboard Jinja2 -peft==0.7.0 \ No newline at end of file +peft==0.7.0 diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 90602ad597a9..74864da20d82 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -35,7 +35,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed from datasets import concatenate_datasets, load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version @@ -50,7 +50,7 @@ from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -58,7 +58,8 @@ check_min_version("0.28.0.dev0") logger = get_logger(__name__) - +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False DATASET_NAME_MAPPING = { "lambdalabs/naruto-blip-captions": ("image", "text"), @@ -460,6 +461,9 @@ def parse_args(input_args=None): ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention." + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) @@ -716,7 +720,12 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) - + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + unet.enable_npu_flash_attention() + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.") if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers @@ -742,7 +751,8 @@ def save_model_hook(models, weights, output_dir): model.save_pretrained(os.path.join(output_dir, "unet")) # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() def load_model_hook(models, input_dir): if args.use_ema: @@ -914,7 +924,7 @@ def preprocess_train(examples): train_dataset_with_vae = train_dataset.map( compute_vae_encodings_fn, batched=True, - batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, + batch_size=args.train_batch_size, new_fingerprint=new_fingerprint_for_vae, ) precomputed_dataset = concatenate_datasets( @@ -1160,7 +1170,8 @@ def compute_time_ids(original_size, crops_coords_top_left): accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 - if accelerator.is_main_process: + # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md new file mode 100644 index 000000000000..0b0f3589baf5 --- /dev/null +++ b/examples/vqgan/README.md @@ -0,0 +1,127 @@ +## Training an VQGAN VAE +VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and was combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is it's a type of a variational auto encoder with tokens as the latent space similar to tokens for LLMs. This script was adapted from a [pr to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52) with general code following [lucidrian's implementation of the vqgan training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py) but both of these implementation follow from the [taming transformer repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file). + + +Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### Training on CIFAR10 + +The command to train a VQGAN model on cifar10 dataset: + +```bash +accelerate launch train_vqgan.py \ + --dataset_name=cifar10 \ + --image_column=img \ + --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \ + --resolution=128 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=8 \ + --report_to=wandb +``` + +An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images). +The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocaburary size which at most can be around 16384. How to do this is shown below. + +# Modifying the architecture + +To modify the architecture of the vqgan model you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide that to the script with the option --model_config_name_or_path. This config is below +``` +{ + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 256, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "AttnDownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 16384, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "vq_embed_dim": 4 +} +``` +To lower the amount of layers in a VQGan, you can remove layers by modifying the block_out_channels, down_block_types, and up_block_types like below +``` +{ + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 256, + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 16384, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "vq_embed_dim": 4 +} +``` +For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that. + +## Extra training tips/ideas +During logging take care to make sure data_time is low. data_time is the amount spent loading the data and where the GPU is not active. So essentially, it's the time wasted. The easiest way to lower data time is to increase the --dataloader_num_workers to a higher number like 4. Due to a bug in Pytorch, this only works on linux based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646) +Secondly, training should seem to be done when both the discriminator and the generator loss converges. +Thirdly, another low hanging fruit is just using ema using the --use_ema parameter. This tends to make the output images smoother. This has a con where you have to lower your batch size by 1 but it may be worth it. +Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, becareful with the feature map norms since this can easily overdominate the loss. \ No newline at end of file diff --git a/examples/vqgan/discriminator.py b/examples/vqgan/discriminator.py new file mode 100644 index 000000000000..eb31c3cb0165 --- /dev/null +++ b/examples/vqgan/discriminator.py @@ -0,0 +1,48 @@ +""" +Ported from Paella +""" + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + + +# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py +class Discriminator(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm( + nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1) + ), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = hidden_channels // (2 ** max((d - i), 0)) + c_out = hidden_channels // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d( + (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1 + ) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view( + cond.size(0), + cond.size(1), + 1, + 1, + ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/examples/vqgan/requirements.txt b/examples/vqgan/requirements.txt new file mode 100644 index 000000000000..f204a70f1e0e --- /dev/null +++ b/examples/vqgan/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +datasets +timm +numpy +tqdm +tensorboard \ No newline at end of file diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py new file mode 100644 index 000000000000..664a7f7365b0 --- /dev/null +++ b/examples/vqgan/test_vqgan.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import shutil +import sys +import tempfile + +import torch + +from diffusers import VQModel +from diffusers.utils.testing_utils import require_timm + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +@require_timm +class TextToImage(ExamplesTestsAccelerate): + @property + def test_vqmodel_config(self): + return { + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 32, + ], + "down_block_types": [ + "DownEncoderBlock2D", + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 32, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def test_discriminator_config(self): + return { + "_class_name": "Discriminator", + "_diffusers_version": "0.27.0.dev0", + "in_channels": 3, + "cond_channels": 0, + "hidden_channels": 8, + "depth": 4, + } + + def get_vq_and_discriminator_configs(self, tmpdir): + vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json") + discriminator_config_path = os.path.join(tmpdir, "discriminator.json") + with open(vqmodel_config_path, "w") as fp: + json.dump(self.test_vqmodel_config, fp) + with open(discriminator_config_path, "w") as fp: + json.dump(self.test_discriminator_config, fp) + return vqmodel_config_path, discriminator_config_path + + def test_vqmodel(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + test_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue( + os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors")) + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors"))) + + def test_vqmodel_checkpointing(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4"}, + ) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=1 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + # In the current script, checkpointing_steps 1 is equivalent to checkpointing_steps 2 as after the generator gets trained for one step, + # the discriminator gets trained and loss and saving happens after that. Thus we do not expect to get a checkpoint-5 + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_vqmodel_checkpointing_use_ema(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=1 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --output_dir {tmpdir} + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_vqmodel_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) + + def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # resume and we should try to checkpoint at 6, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 8 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8"}, + ) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py new file mode 100644 index 000000000000..b7beee1f3b26 --- /dev/null +++ b/examples/vqgan/train_vqgan.py @@ -0,0 +1,1067 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +import shutil +import time +from pathlib import Path + +import accelerate +import numpy as np +import PIL +import PIL.Image +import timm +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from datasets import load_dataset +from discriminator import Discriminator +from huggingface_hub import create_repo +from packaging import version +from PIL import Image +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from torchvision import transforms +from tqdm import tqdm + +from diffusers import VQModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.27.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def _map_layer_to_idx(backbone, layers, offset=0): + """Maps set of layer names to indices of model. Ported from anomalib + + Returns: + Feature map extracted from the CNN + """ + idx = [] + features = timm.create_model( + backbone, + pretrained=False, + features_only=False, + exportable=True, + ) + for i in layers: + try: + idx.append(list(dict(features.named_children()).keys()).index(i) - offset) + except ValueError: + raise ValueError( + f"Layer {i} not found in model {backbone}. Select layer from {list(dict(features.named_children()).keys())}. The network architecture is {features}" + ) + return idx + + +def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, timm_model_normalization): + img_timm_model_input = timm_model_normalization(F.interpolate(pixel_values, timm_model_resolution)) + fmap_timm_model_input = timm_model_normalization(F.interpolate(fmap, timm_model_resolution)) + + if pixel_values.shape[1] == 1: + # handle grayscale for timm_model + img_timm_model_input, fmap_timm_model_input = ( + t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input) + ) + + img_timm_model_feats = timm_model(img_timm_model_input) + recon_timm_model_feats = timm_model(fmap_timm_model_input) + perceptual_loss = F.mse_loss(img_timm_model_feats[0], recon_timm_model_feats[0]) + for i in range(1, len(img_timm_model_feats)): + perceptual_loss += F.mse_loss(img_timm_model_feats[i], recon_timm_model_feats[i]) + perceptual_loss /= len(img_timm_model_feats) + return perceptual_loss + + +def grad_layer_wrt_loss(loss, layer): + return torch.autograd.grad( + outputs=loss, + inputs=layer, + grad_outputs=torch.ones_like(loss), + retain_graph=True, + )[0].detach() + + +def gradient_penalty(images, output, weight=10): + gradients = torch.autograd.grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + bsz = gradients.shape[0] + gradients = torch.reshape(gradients, (bsz, -1)) + return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + +@torch.no_grad() +def log_validation(model, args, validation_transform, accelerator, global_step): + logger.info("Generating images...") + dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + dtype = torch.bfloat16 + original_images = [] + for image_path in args.validation_images: + image = PIL.Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = validation_transform(image).to(accelerator.device, dtype=dtype) + original_images.append(image[None]) + # Generate images + model.eval() + images = [] + for original_image in original_images: + image = accelerator.unwrap_model(model)(original_image).sample + images.append(image) + model.train() + original_images = torch.cat(original_images, dim=0) + images = torch.cat(images, dim=0) + + # Convert to PIL images + images = torch.clamp(images, 0.0, 1.0) + original_images = torch.clamp(original_images, 0.0, 1.0) + images *= 255.0 + original_images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + images = np.concatenate([original_images, images], axis=2) + images = [Image.fromarray(image) for image in images] + + # Log images + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images) + ] + }, + step=global_step, + ) + torch.cuda.empty_cache() + return images + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--log_grad_norm_steps", + type=int, + default=500, + help=("Print logs of gradient norms every X steps."), + ) + parser.add_argument( + "--log_steps", + type=int, + default=50, + help=("Print logs every X steps."), + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running reconstruction on images in" + " `args.validation_images` and logging the reconstructed images." + ), + ) + parser.add_argument( + "--vae_loss", + type=str, + default="l2", + help="The loss function for vae reconstruction loss.", + ) + parser.add_argument( + "--timm_model_offset", + type=int, + default=0, + help="Offset of timm layers to indices.", + ) + parser.add_argument( + "--timm_model_layers", + type=str, + default="head", + help="The layers to get output from in the timm model.", + ) + parser.add_argument( + "--timm_model_backend", + type=str, + default="vgg19", + help="Timm model used to get the lpips loss", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the Vq model to train, leave as None to use standard Vq model configuration.", + ) + parser.add_argument( + "--discriminator_config_name_or_path", + type=str, + default=None, + help="The config of the discriminator model to train, leave as None to use standard Vq model configuration.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + nargs="+", + help=("A set of validation images evaluated every `--validation_steps` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="vqgan-output", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--discr_learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--discr_lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="vqgan-training", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + args = parse_args() + + # Enable TF32 on Ampere GPUs + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_images") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None: + # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder but without the attention layers + model = VQModel( + act_fn="silu", + block_out_channels=[ + 128, + 256, + 512, + ], + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + norm_type="spatial", + num_vq_embeddings=16384, + out_channels=3, + sample_size=32, + scaling_factor=0.18215, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + vq_embed_dim=4, + ) + elif args.pretrained_model_name_or_path is not None: + model = VQModel.from_pretrained(args.pretrained_model_name_or_path) + else: + config = VQModel.load_config(args.model_config_name_or_path) + model = VQModel.from_config(config) + if args.use_ema: + ema_model = EMAModel(model.parameters(), model_cls=VQModel, model_config=model.config) + if args.discriminator_config_name_or_path is None: + discriminator = Discriminator() + else: + config = Discriminator.load_config(args.discriminator_config_name_or_path) + discriminator = Discriminator.from_config(config) + + idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split("|"), args.timm_model_offset) + + timm_model = timm.create_model( + args.timm_model_backend, + pretrained=True, + features_only=True, + exportable=True, + out_indices=idx, + ) + timm_model = timm_model.to(accelerator.device) + timm_model.requires_grad = False + timm_model.eval() + timm_transform = create_transform(**resolve_data_config(timm_model.pretrained_cfg, model=timm_model)) + try: + # Gets the resolution of the timm transformation after centercrop + timm_centercrop_transform = timm_transform.transforms[1] + assert isinstance( + timm_centercrop_transform, transforms.CenterCrop + ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + timm_model_resolution = timm_centercrop_transform.size[0] + # Gets final normalization + timm_model_normalization = timm_transform.transforms[-1] + assert isinstance( + timm_model_normalization, transforms.Normalize + ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + except AssertionError as e: + raise NotImplementedError(e) + # Enable flash attention if asked + if args.enable_xformers_memory_efficient_attention: + model.enable_xformers_memory_efficient_attention() + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "vqmodel_ema")) + vqmodel = models[0] + discriminator = models[1] + vqmodel.save_pretrained(os.path.join(output_dir, "vqmodel")) + discriminator.save_pretrained(os.path.join(output_dir, "discriminator")) + weights.pop() + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vqmodel_ema"), VQModel) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + discriminator = models.pop() + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") + discriminator.register_to_config(**load_model.config) + discriminator.load_state_dict(load_model.state_dict()) + del load_model + vqmodel = models.pop() + load_model = VQModel.from_pretrained(input_dir, subfolder="vqmodel") + vqmodel.register_to_config(**load_model.config) + vqmodel.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + learning_rate = args.learning_rate + if args.scale_lr: + learning_rate = ( + learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + list(model.parameters()), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + discr_optimizer = optimizer_cls( + list(discriminator.parameters()), + lr=args.discr_learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + ################################## + # DATLOADER and LR-SCHEDULER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + args.train_batch_size * accelerator.num_processes + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + # DataLoaders creation: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + assert args.image_column is not None + image_column = args.image_column + if image_column not in column_names: + raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}") + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + ] + ) + validation_transform = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + return {"pixel_values": pixel_values} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps=args.max_train_steps, + num_warmup_steps=args.lr_warmup_steps, + ) + discr_lr_scheduler = get_scheduler( + args.discr_lr_scheduler, + optimizer=discr_optimizer, + num_training_steps=args.max_train_steps, + num_warmup_steps=args.lr_warmup_steps, + ) + + # Prepare everything with accelerator + logger.info("Preparing model, optimizer and dataloaders") + # The dataloader are already aware of distributed training, so we don't need to prepare them. + model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare( + model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler + ) + if args.use_ema: + ema_model.to(accelerator.device) + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Potentially load in the weights and states from a previous save + resume_from_checkpoint = args.resume_from_checkpoint + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = resume_from_checkpoint + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + path = os.path.join(args.output_dir, path) + + if path is None: + accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") + resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(path) + accelerator.wait_for_everyone() + global_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + avg_gen_loss, avg_discr_loss = None, None + for epoch in range(first_epoch, args.num_train_epochs): + model.train() + discriminator.train() + for i, batch in enumerate(train_dataloader): + pixel_values = batch["pixel_values"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + generator_step = ((i // args.gradient_accumulation_steps) % 2) == 0 + # Train Step + # The behavior of accelerator.accumulate is to + # 1. Check if gradients are synced(reached gradient-accumulation_steps) + # 2. If so sync gradients by stopping the not syncing process + if generator_step: + optimizer.zero_grad(set_to_none=True) + else: + discr_optimizer.zero_grad(set_to_none=True) + # encode images to the latent space and get the commit loss from vq tokenization + # Return commit loss + fmap, commit_loss = model(pixel_values, return_dict=False) + + if generator_step: + with accelerator.accumulate(model): + # reconstruction loss. Pixel level differences between input vs output + if args.vae_loss == "l2": + loss = F.mse_loss(pixel_values, fmap) + else: + loss = F.l1_loss(pixel_values, fmap) + # perceptual loss. The high level feature mean squared error loss + perceptual_loss = get_perceptual_loss( + pixel_values, + fmap, + timm_model, + timm_model_resolution=timm_model_resolution, + timm_model_normalization=timm_model_normalization, + ) + # generator loss + gen_loss = -discriminator(fmap).mean() + last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) + + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-8) + adaptive_weight = adaptive_weight.clamp(max=1e4) + loss += commit_loss + loss += perceptual_loss + loss += adaptive_weight * gen_loss + # Gather the losses across all processes for logging (if we use distributed training). + avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean() + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and global_step % args.log_grad_norm_steps == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step) + else: + # Return discriminator loss + with accelerator.accumulate(discriminator): + fmap.detach_() + pixel_values.requires_grad_() + real = discriminator(pixel_values) + fake = discriminator(fmap) + loss = (F.relu(1 + fake) + F.relu(1 - real)).mean() + gp = gradient_penalty(pixel_values, real) + loss += gp + avg_discr_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + + discr_optimizer.step() + discr_lr_scheduler.step() + if ( + accelerator.sync_gradients + and global_step % args.log_grad_norm_steps == 0 + and accelerator.is_main_process + ): + log_grad_norm(discriminator, accelerator, global_step) + batch_time_m.update(time.time() - end) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) + if args.use_ema: + ema_model.step(model.parameters()) + if accelerator.sync_gradients and not generator_step and accelerator.is_main_process: + # wait for both generator and discriminator to settle + # Log metrics + if global_step % args.log_steps == 0: + samples_per_second_per_gpu = ( + args.gradient_accumulation_steps * args.train_batch_size / batch_time_m.val + ) + logs = { + "step_discr_loss": avg_discr_loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + if avg_gen_loss is not None: + logs["step_gen_loss"] = avg_gen_loss.item() + accelerator.log(logs, step=global_step) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Generate images + if global_step % args.validation_steps == 0: + if args.use_ema: + # Store the VQGAN parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + log_validation(model, args, validation_transform, accelerator, global_step) + if args.use_ema: + # Switch back to the original VQGAN parameters. + ema_model.restore(model.parameters()) + end = time.time() + # Stop training if max steps is reached + if global_step >= args.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + discriminator = accelerator.unwrap_model(discriminator) + if args.use_ema: + ema_model.copy_to(model.parameters()) + model.save_pretrained(os.path.join(args.output_dir, "vqmodel")) + discriminator.save_pretrained(os.path.join(args.output_dir, "discriminator")) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py new file mode 100644 index 000000000000..38542407e31f --- /dev/null +++ b/src/diffusers/callbacks.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME + + +class PipelineCallback(ConfigMixin): + """ + Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing + custom callbacks and ensures that all callbacks have a consistent interface. + + Please implement the following: + `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to + include + variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + `callback_fn`: This method defines the core functionality of your callback. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): + super().__init__() + + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): + raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") + + @property + def tensor_inputs(self) -> List[str]: + raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") + + def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: + raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) + + +class MultiPipelineCallbacks: + """ + This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and + provides a unified interface for calling all of them. + """ + + def __init__(self, callbacks: List[PipelineCallback]): + self.callbacks = callbacks + + @property + def tensor_inputs(self) -> List[str]: + return [input for callback in self.callbacks for input in callback.tensor_inputs] + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + """ + Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. + """ + for callback in self.callbacks: + callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) + + return callback_kwargs + + +class SDCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + + +class SDXLCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + return callback_kwargs + + +class IPAdapterScaleCutoffCallback(PipelineCallback): + """ + Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`. + + Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step. + """ + + tensor_inputs = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + pipeline.set_ip_adapter_scale(0.0) + return callback_kwargs diff --git a/src/diffusers/commands/env.py b/src/diffusers/commands/env.py index baa69b361f5d..024b5e6ec786 100644 --- a/src/diffusers/commands/env.py +++ b/src/diffusers/commands/env.py @@ -13,12 +13,24 @@ # limitations under the License. import platform +import subprocess from argparse import ArgumentParser import huggingface_hub from .. import __version__ as version -from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available +from ..utils import ( + is_accelerate_available, + is_bitsandbytes_available, + is_flax_available, + is_google_colab, + is_notebook, + is_peft_available, + is_safetensors_available, + is_torch_available, + is_transformers_available, + is_xformers_available, +) from . import BaseDiffusersCLICommand @@ -28,13 +40,19 @@ def info_command_factory(_): class EnvironmentCommand(BaseDiffusersCLICommand): @staticmethod - def register_subcommand(parser: ArgumentParser): + def register_subcommand(parser: ArgumentParser) -> None: download_parser = parser.add_parser("env") download_parser.set_defaults(func=info_command_factory) - def run(self): + def run(self) -> dict: hub_version = huggingface_hub.__version__ + safetensors_version = "not installed" + if is_safetensors_available(): + import safetensors + + safetensors_version = safetensors.__version__ + pt_version = "not installed" pt_cuda_available = "NA" if is_torch_available(): @@ -43,6 +61,20 @@ def run(self): pt_version = torch.__version__ pt_cuda_available = torch.cuda.is_available() + flax_version = "not installed" + jax_version = "not installed" + jaxlib_version = "not installed" + jax_backend = "NA" + if is_flax_available(): + import flax + import jax + import jaxlib + + flax_version = flax.__version__ + jax_version = jax.__version__ + jaxlib_version = jaxlib.__version__ + jax_backend = jax.lib.xla_bridge.get_backend().platform + transformers_version = "not installed" if is_transformers_available(): import transformers @@ -55,21 +87,87 @@ def run(self): accelerate_version = accelerate.__version__ + peft_version = "not installed" + if is_peft_available(): + import peft + + peft_version = peft.__version__ + + bitsandbytes_version = "not installed" + if is_bitsandbytes_available(): + import bitsandbytes + + bitsandbytes_version = bitsandbytes.__version__ + xformers_version = "not installed" if is_xformers_available(): import xformers xformers_version = xformers.__version__ + is_notebook_str = "Yes" if is_notebook() else "No" + + is_google_colab_str = "Yes" if is_google_colab() else "No" + + accelerator = "NA" + if platform.system() in {"Linux", "Windows"}: + try: + sp = subprocess.Popen( + ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out_str, _ = sp.communicate() + out_str = out_str.decode("utf-8") + + if len(out_str) > 0: + accelerator = out_str.strip() + " VRAM" + except FileNotFoundError: + pass + elif platform.system() == "Darwin": # Mac OS + try: + sp = subprocess.Popen( + ["system_profiler", "SPDisplaysDataType"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out_str, _ = sp.communicate() + out_str = out_str.decode("utf-8") + + start = out_str.find("Chipset Model:") + if start != -1: + start += len("Chipset Model:") + end = out_str.find("\n", start) + accelerator = out_str[start:end].strip() + + start = out_str.find("VRAM (Total):") + if start != -1: + start += len("VRAM (Total):") + end = out_str.find("\n", start) + accelerator += " VRAM: " + out_str[start:end].strip() + except FileNotFoundError: + pass + else: + print("It seems you are running an unusual OS. Could you fill in the accelerator manually?") + info = { - "`diffusers` version": version, - "Platform": platform.platform(), + "🤗 Diffusers version": version, + "Platform": f"{platform.freedesktop_os_release().get('PRETTY_NAME', None)} - {platform.platform()}", + "Running on a notebook?": is_notebook_str, + "Running on Google Colab?": is_google_colab_str, "Python version": platform.python_version(), "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})", + "Jax version": jax_version, + "JaxLib version": jaxlib_version, "Huggingface_hub version": hub_version, "Transformers version": transformers_version, "Accelerate version": accelerate_version, + "PEFT version": peft_version, + "Bitsandbytes version": bitsandbytes_version, + "Safetensors version": safetensors_version, "xFormers version": xformers_version, + "Accelerator": accelerator, "Using GPU in script?": "", "Using distributed or parallel set-up in script?": "", } @@ -80,5 +178,5 @@ def run(self): return info @staticmethod - def format_dict(d): + def format_dict(d: dict) -> str: return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 97ea0d9e589c..e089d202ee75 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -363,7 +363,7 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None: + if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index a9b9a9aae052..b6e1545e16dd 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -419,19 +419,20 @@ def load_textual_inversion( # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again is_model_cpu_offload = False is_sequential_cpu_offload = False - for _, component in self.components.items(): - if isinstance(component, nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = ( - isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - logger.info( - "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." - ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + if self.hf_device_map is None: + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + logger.info( + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) # 7.2 save expected device and dtype device = text_encoder.device diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 333842905bc3..bb80ce8605ba 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -41,6 +41,7 @@ class DecoderOutput(BaseOutput): """ sample: torch.Tensor + commit_loss: Optional[torch.FloatTensor] = None class Encoder(nn.Module): diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py new file mode 100644 index 000000000000..635cd0ba5728 --- /dev/null +++ b/src/diffusers/models/model_loading_utils.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from collections import OrderedDict +from typing import List, Optional, Union + +import safetensors +import torch + +from ..utils import ( + SAFETENSORS_FILE_EXTENSION, + is_accelerate_available, + is_torch_version, + logging, +) + + +logger = logging.get_logger(__name__) + + +if is_accelerate_available(): + from accelerate import infer_auto_device_map + from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device + + +# Adapted from `transformers` (see modeling_utils.py) +def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): + if isinstance(device_map, str): + no_split_modules = model._get_no_split_modules(device_map) + device_map_kwargs = {"no_split_module_classes": no_split_modules} + + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=torch_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + + device_map_kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + + return device_map + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: + return safetensors.torch.load_file(checkpoint_file, device="cpu") + else: + weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} + return torch.load( + checkpoint_file, + map_location="cpu", + **weights_only_kwarg, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " + ) + + +def load_model_dict_into_meta( + model, + state_dict: OrderedDict, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + model_name_or_path: Optional[str] = None, +) -> List[str]: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + + unexpected_keys = [] + empty_state_dict = model.state_dict() + for param_name, param in state_dict.items(): + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + + if empty_state_dict[param_name].shape != param.shape: + model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" + raise ValueError( + f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + else: + set_module_tensor_to_device(model, param_name, device, value=param) + return unexpected_keys + + +def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix: str = ""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e8ba1ed65de3..2ed5655c84a9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -33,7 +33,6 @@ from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, - SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, @@ -44,6 +43,12 @@ logging, ) from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card +from .model_loading_utils import ( + _determine_device_map, + _load_state_dict_into_model, + load_model_dict_into_meta, + load_state_dict, +) logger = logging.get_logger(__name__) @@ -57,9 +62,6 @@ if is_accelerate_available(): import accelerate - from accelerate import infer_auto_device_map - from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device - from accelerate.utils.versions import is_torch_version def get_parameter_device(parameter: torch.nn.Module) -> torch.device: @@ -100,117 +102,6 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].dtype -# Adapted from `transformers` (see modeling_utils.py) -def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype): - if isinstance(device_map, str): - no_split_modules = model._get_no_split_modules(device_map) - device_map_kwargs = {"no_split_module_classes": no_split_modules} - - if device_map != "sequential": - max_memory = get_balanced_memory( - model, - dtype=torch_dtype, - low_zero=(device_map == "balanced_low_0"), - max_memory=max_memory, - **device_map_kwargs, - ) - else: - max_memory = get_max_memory(max_memory) - - device_map_kwargs["max_memory"] = max_memory - device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) - - return device_map - - -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): - """ - Reads a checkpoint file, returning properly formatted errors if they arise. - """ - try: - file_extension = os.path.basename(checkpoint_file).split(".")[-1] - if file_extension == SAFETENSORS_FILE_EXTENSION: - return safetensors.torch.load_file(checkpoint_file, device="cpu") - else: - weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} - return torch.load( - checkpoint_file, - map_location="cpu", - **weights_only_kwarg, - ) - except Exception as e: - try: - with open(checkpoint_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " - "model. Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " - ) - - -def load_model_dict_into_meta( - model, - state_dict: OrderedDict, - device: Optional[Union[str, torch.device]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - model_name_or_path: Optional[str] = None, -) -> List[str]: - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - - unexpected_keys = [] - empty_state_dict = model.state_dict() - for param_name, param in state_dict.items(): - if param_name not in empty_state_dict: - unexpected_keys.append(param_name) - continue - - if empty_state_dict[param_name].shape != param.shape: - model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" - raise ValueError( - f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." - ) - - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) - else: - set_module_tensor_to_device(model, param_name, device, value=param) - return unexpected_keys - - -def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix: str = ""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(model_to_load) - - return error_msgs - - class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models. diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index d07100b10edd..ad45a43b5023 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -685,7 +685,7 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + elif isinstance(cross_attention_dim, (list, tuple)): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index a092daa662f7..1b62d16d5d77 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config @@ -27,6 +28,9 @@ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin @@ -490,6 +494,36 @@ def from_unet2d( model.time_proj.load_state_dict(unet.time_proj.state_dict()) model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + if any( + isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + for proc in unet.attn_processors.values() + ): + attn_procs = {} + for name, processor in unet.attn_processors.items(): + if name.endswith("attn1.processor"): + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=processor.hidden_size, + cross_attention_dim=processor.cross_attention_dim, + scale=processor.scale, + num_tokens=processor.num_tokens, + ) + for name, processor in model.attn_processors.items(): + if name not in attn_procs: + attn_procs[name] = processor.__class__() + model.set_attn_processor(attn_procs) + model.config.encoder_hid_dim_type = "ip_image_proj" + model.encoder_hid_proj = unet.encoder_hid_proj + for i, down_block in enumerate(unet.down_blocks): model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) if hasattr(model.down_blocks[i], "attentions"): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 2e38d5b6711a..cb32b1f40734 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -142,18 +142,20 @@ def decode( ) -> Union[DecoderOutput, torch.Tensor]: # also go through quantization layer if not force_not_quantize: - quant, _, _ = self.quantize(h) + quant, commit_loss, _ = self.quantize(h) elif self.config.lookup_from_codebook: quant = self.quantize.get_codebook_entry(h, shape) + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) else: quant = h + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) quant2 = self.post_quant_conv(quant) dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) if not return_dict: - return (dec,) + return dec, commit_loss - return DecoderOutput(sample=dec) + return DecoderOutput(sample=dec, commit_loss=commit_loss) def forward( self, sample: torch.Tensor, return_dict: bool = True @@ -173,9 +175,8 @@ def forward( """ h = self.encode(sample).latents - dec = self.decode(h).sample + dec = self.decode(h) if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) + return dec.sample, dec.commit_loss + return dec diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py index 98e9eec94a91..e3c5ec6eed03 100644 --- a/src/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -100,20 +100,16 @@ def save_pretrained( variant (`str`, *optional*): If specified, weights are saved in the format pytorch_model..bin. """ - idx = 0 - model_path_to_save = save_directory - for controlnet in self.nets: + for idx, controlnet in enumerate(self.nets): + suffix = "" if idx == 0 else f"_{idx}" controlnet.save_pretrained( - model_path_to_save, + save_directory + suffix, is_main_process=is_main_process, save_function=save_function, safe_serialization=safe_serialization, variant=variant, ) - idx += 1 - model_path_to_save = model_path_to_save + f"_{idx}" - @classmethod def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): r""" diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index e64dcdc55457..cf979c352cfb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -926,7 +927,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1019,11 +1022,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1055,6 +1058,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2e44efa78b73..e6f1a06bddc2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -917,7 +918,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1004,11 +1007,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1040,6 +1043,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index cdc34819d59e..d29e3ac8f90e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel @@ -1134,7 +1135,9 @@ def __call__( control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1239,11 +1242,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1275,6 +1278,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 3cfdefa9d44d..56c2dc10a570 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -27,6 +27,7 @@ CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -197,8 +198,26 @@ class StableDiffusionXLControlNetInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] def __init__( self, @@ -1178,7 +1197,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1317,11 +1338,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1351,6 +1372,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1730,7 +1754,7 @@ def denoising_value_valid(dnv): down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - if ip_adapter_image is not None: + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds if num_channels_unet == 9: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 73bb8be89eb1..f6992ac08c31 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -30,6 +30,7 @@ from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -114,6 +115,66 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLControlNetPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -175,7 +236,15 @@ class StableDiffusionXLControlNetPipeline( "feature_extractor", "image_encoder", ] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] def __init__( self, @@ -941,6 +1010,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -969,7 +1040,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1001,6 +1074,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -1099,11 +1180,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1133,6 +1214,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1265,8 +1349,9 @@ def __call__( assert False # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) self._num_timesteps = len(timesteps) # 6. Prepare latent variables @@ -1451,6 +1536,12 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index dbd406d928d5..0059fb1a729f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -30,6 +30,7 @@ from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -227,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline( "feature_extractor", "image_encoder", ] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] def __init__( self, @@ -1105,7 +1114,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1254,11 +1265,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1288,6 +1299,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1578,6 +1592,12 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 50cd24e4fa18..3675a99ba67f 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel @@ -648,7 +649,9 @@ def __call__( control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" @@ -715,11 +718,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -734,6 +737,9 @@ def __call__( "not-safe-for-work" (nsfw) content. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index e572412f6e91..24640a6067c0 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -28,6 +28,7 @@ from diffusers.utils.import_utils import is_invisible_watermark_available +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel @@ -157,7 +158,15 @@ class StableDiffusionXLControlNetXSPipeline( "text_encoder_2", "feature_extractor", ] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] def __init__( self, @@ -739,7 +748,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" @@ -851,11 +862,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -869,6 +880,9 @@ def __call__( returned, otherwise a `tuple` is returned containing the output images. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 9e172ec2dca6..e1ba3879f568 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -817,7 +817,7 @@ def __init__( positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + elif isinstance(cross_attention_dim, (list, tuple)): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 9794b89f7611..6d3f5c1e274d 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -366,7 +366,7 @@ def encode_prompt( ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" + "The following part of your input was truncated because T5 can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index 5e3cc668d6a8..8b52bd487eb6 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -292,7 +292,7 @@ def encode_prompt( ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" + "The following part of your input was truncated because T5 can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 2ec60fbd619c..cd9ec57fb879 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -197,7 +197,7 @@ def check_inputs( ) # verify batch size of prompt and image are same if image is a list or tensor or numpy array - if isinstance(image, list) or isinstance(image, np.ndarray): + if isinstance(image, (list, np.ndarray)): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e8ab72421d7e..087ce151f0b4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -19,6 +18,7 @@ from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -775,7 +775,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -845,11 +847,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -881,6 +883,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2a5de81540d..d806f230a06a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -21,6 +21,7 @@ from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -862,7 +863,9 @@ def __call__( return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -932,11 +935,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -967,6 +970,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 71dec964fdca..37a3b32994ff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -21,6 +21,7 @@ from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -1014,7 +1015,9 @@ def __call__( return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1107,11 +1110,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1171,6 +1174,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index b2b2b14009db..1443c8b0af52 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel @@ -175,7 +176,9 @@ def __call__( ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -227,11 +230,11 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -290,6 +293,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Check inputs self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 978e2dbb60ab..4d033133e5ec 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -221,7 +221,7 @@ def check_inputs(self, prompt, image, callback_steps): ) # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(image, (list, torch.Tensor)): if isinstance(prompt, str): batch_size = 1 else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 3981c8a461fe..07aafe8821b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -468,7 +468,7 @@ def check_inputs( ) # verify batch size of prompt and image are same if image is a list or tensor or numpy array - if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): + if isinstance(image, (list, np.ndarray, torch.Tensor)): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 1323542e40b1..967d525c7397 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -185,7 +185,7 @@ def preprocess(image): def preprocess_mask(mask, batch_size: int = 1): if not isinstance(mask, torch.Tensor): # preprocess mask - if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray): + if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 52d0b07fb315..2568150fa5f2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -24,6 +24,7 @@ CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -861,7 +862,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -992,11 +995,11 @@ def __call__( as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1026,6 +1029,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b8698a008320..838489dca778 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -25,6 +25,7 @@ CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1008,7 +1009,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1157,11 +1160,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1191,6 +1194,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 38f5cec931f8..631e309993b1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -26,6 +26,7 @@ CLIPVisionModelWithProjection, ) +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, @@ -1243,7 +1244,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): @@ -1411,11 +1414,11 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the @@ -1445,6 +1448,9 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 8b8da5d110d2..653171638ccf 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -347,11 +347,7 @@ def step( otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index b37e6e0fd74a..121ac0d174f6 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -310,11 +310,7 @@ def step( returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 085683e56f70..0e5904539a14 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -375,11 +375,7 @@ def step( """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 5f9db844ff35..de1f96a073ef 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -530,11 +530,7 @@ def step( returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cdc92036613d..04f91d758b94 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -58,20 +58,25 @@ get_objects_from_module, is_accelerate_available, is_accelerate_version, + is_bitsandbytes_available, is_bs4_available, is_flax_available, is_ftfy_available, + is_google_colab, is_inflect_available, is_invisible_watermark_available, is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, is_note_seq_available, + is_notebook, is_onnx_available, is_peft_available, is_peft_version, + is_safetensors_available, is_scipy_available, is_tensorboard_available, + is_timm_available, is_torch_available, is_torch_npu_available, is_torch_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f5f57d8a5c5f..b8ce2d7c0466 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -295,6 +295,39 @@ except importlib_metadata.PackageNotFoundError: _torchvision_available = False +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + + +_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None +try: + _bitsandbytes_version = importlib_metadata.version("bitsandbytes") + logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") +except importlib_metadata.PackageNotFoundError: + _bitsandbytes_available = False + +# Taken from `huggingface_hub`. +_is_notebook = False +try: + shell_class = get_ipython().__class__ # type: ignore # noqa: F821 + for parent_class in shell_class.__mro__: # e.g. "is subclass of" + if parent_class.__name__ == "ZMQInteractiveShell": + _is_notebook = True # Jupyter notebook, Google colab or qtconsole + break +except NameError: + pass # Probably standard Python interpreter + +_is_google_colab = "google.colab" in sys.modules + def is_torch_available(): return _torch_available @@ -392,6 +425,22 @@ def is_torchvision_available(): return _torchvision_available +def is_safetensors_available(): + return _safetensors_available + + +def is_bitsandbytes_available(): + return _bitsandbytes_available + + +def is_notebook(): + return _is_notebook + + +def is_google_colab(): + return _is_google_colab + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -499,6 +548,20 @@ def is_torchvision_available(): {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` """ +# docstyle-ignore +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft` +""" + +# docstyle-ignore +SAFETENSORS_IMPORT_ERROR = """ +{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors` +""" + +# docstyle-ignore +BITSANDBYTES_IMPORT_ERROR = """ +{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes` +""" BACKENDS_MAPPING = OrderedDict( [ @@ -520,6 +583,9 @@ def is_torchvision_available(): ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)), + ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index c1756d6590d1..8a6afd768428 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -33,6 +33,7 @@ is_onnx_available, is_opencv_available, is_peft_available, + is_timm_available, is_torch_available, is_torch_version, is_torchsde_available, @@ -340,6 +341,13 @@ def require_peft_backend(test_case): return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) +def require_timm(test_case): + """ + Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. + """ + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + + def require_peft_version_greater(peft_version): """ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 92cb6b76f195..f6ca4f304eb9 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -224,7 +224,7 @@ def test_sdxl_1_0_blockwise_lora(self): ).images images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + expected = np.array([00.4468, 0.4061, 0.4134, 0.3637, 0.3202, 0.365, 0.3786, 0.3725, 0.3535]) max_diff = numpy_cosine_similarity_distance(expected, images) assert max_diff < 1e-4 @@ -507,13 +507,12 @@ def test_controlnet_canny_lora(self): image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" ) - images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images assert images[0].shape == (768, 512, 3) original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333]) + expected_image = np.array([0.4574, 0.4487, 0.4435, 0.5163, 0.4396, 0.4411, 0.518, 0.4465, 0.4333]) max_diff = numpy_cosine_similarity_distance(expected_image, original_image) assert max_diff < 1e-4 diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index d4262e2709dc..c61ae1bdf0ff 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -98,3 +98,19 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + def test_loss_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.to(torch_device).eval() + + torch.manual_seed(0) + backend_manual_seed(torch_device, 0) + + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + output = model(image).commit_loss.cpu() + # fmt: off + expected_output = torch.tensor([0.1936]) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py index a7423bebd939..3341f6704e75 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py @@ -19,7 +19,15 @@ import numpy as np import torch from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from diffusers import ( AutoencoderKL, @@ -34,6 +42,7 @@ from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) @@ -55,6 +64,14 @@ class ControlNetPipelineSDXLFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset(IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"mask_image", "control_image"})) image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union( + { + "add_text_embeds", + "add_time_ids", + "mask", + "masked_image_latents", + } + ) def get_dummy_components(self): torch.manual_seed(0) @@ -129,6 +146,30 @@ def get_dummy_components(self): text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + image_encoder_config = CLIPVisionConfig( + hidden_size=32, + image_size=224, + projection_dim=32, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + feature_extractor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + components = { "unet": unet, "controlnet": controlnet, @@ -138,6 +179,8 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "image_encoder": image_encoder, + "feature_extractor": feature_extractor, } return components diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py index 61ff675856ae..766eea4c0c32 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py @@ -34,6 +34,7 @@ IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, ) from ..test_pipelines_common import ( IPAdapterTesterMixin, @@ -55,9 +56,13 @@ class ControlNetPipelineSDXLImg2ImgFastTests( ): pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union( + {"add_text_embeds", "add_time_ids", "add_neg_time_ids"} + ) def get_dummy_components(self, skip_first_text_encoder=False): torch.manual_seed(0) diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index 2f650660d593..99c884fae06b 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -82,6 +82,7 @@ def test_single_file_legacy_scaling_factor(self): assert pipe.vae.config.scaling_factor == new_scaling_factor +@slow class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionPipeline ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors" From b6b6dd9626ae2d07d80efeaf9e05e87a2f8c4715 Mon Sep 17 00:00:00 2001 From: neuron-party Date: Thu, 16 May 2024 14:37:23 -0700 Subject: [PATCH 2/2] supporting custom timesteps and sigmas for forward call in inpainting and img2img controlnet sdxl pipelines --- .../pipeline_controlnet_inpaint_sd_xl.py | 74 +++++++++++++++++- .../pipeline_controlnet_sd_xl_img2img.py | 75 ++++++++++++++++++- 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 56c2dc10a570..523cc4434181 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -151,6 +151,66 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLControlNetInpaintPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -1168,6 +1228,8 @@ def __call__( padding_mask_crop: Optional[int] = None, strength: float = 0.9999, num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, @@ -1243,6 +1305,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1487,7 +1557,9 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 0059fb1a729f..1286a48a9ede 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -158,6 +158,66 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLControlNetImg2ImgPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -1085,6 +1145,8 @@ def __call__( width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -1159,6 +1221,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -1437,10 +1507,11 @@ def __call__( assert False # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - self._num_timesteps = len(timesteps) # 6. Prepare latent variables if latents is None: