From 326f326e1781b1fb888611a37795b474fe496dd8 Mon Sep 17 00:00:00 2001 From: Jongwoo Han Date: Tue, 16 May 2023 20:51:10 +0900 Subject: [PATCH 01/33] Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen --- .github/actions/setup-miniconda/action.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/setup-miniconda/action.yml b/.github/actions/setup-miniconda/action.yml index 8a82ae8b17bf..cc755d3aad79 100644 --- a/.github/actions/setup-miniconda/action.yml +++ b/.github/actions/setup-miniconda/action.yml @@ -27,7 +27,7 @@ runs: - name: Get date id: get-date shell: bash - run: echo "::set-output name=today::$(/bin/date -u '+%Y%m%d')d" + run: echo "today=$(/bin/date -u '+%Y%m%d')d" >> $GITHUB_OUTPUT - name: Setup miniconda cache id: miniconda-cache uses: actions/cache@v2 @@ -143,4 +143,4 @@ runs: echo "There is ${AVAIL}KB free space left in $MOUNT, continue" fi fi - done \ No newline at end of file + done From d2285f51589bbee18673272611b709d306e7f911 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 May 2023 13:58:24 +0200 Subject: [PATCH 02/33] fix warning message pipeline loading (#3446) --- src/diffusers/pipelines/pipeline_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 9288248d309b..a4d3dd1f1673 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -296,8 +296,7 @@ def maybe_raise_or_warn( if not issubclass(model_cls, expected_class_obj): raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" + f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" ) else: logger.warning( From 9d44e2fb6600e80f410b2c05139c001fb0fa9794 Mon Sep 17 00:00:00 2001 From: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com> Date: Tue, 16 May 2023 06:28:01 -0700 Subject: [PATCH 03/33] add stable diffusion tensorrt img2img pipeline (#3419) * add stable diffusion tensorrt img2img pipeline Signed-off-by: Asfiya Baig * update docstrings Signed-off-by: Asfiya Baig --------- Signed-off-by: Asfiya Baig --- examples/community/README.md | 44 +- .../stable_diffusion_tensorrt_img2img.py | 1055 +++++++++++++++++ .../stable_diffusion_tensorrt_txt2img.py | 10 +- 3 files changed, 1102 insertions(+), 7 deletions(-) mode change 100644 => 100755 examples/community/README.md create mode 100755 examples/community/stable_diffusion_tensorrt_img2img.py mode change 100644 => 100755 examples/community/stable_diffusion_tensorrt_txt2img.py diff --git a/examples/community/README.md b/examples/community/README.md old mode 100644 new mode 100755 index 3d034b30fcff..47b129ce9e7e --- a/examples/community/README.md +++ b/examples/community/README.md @@ -31,11 +31,10 @@ If a community doesn't work as expected, please open an issue and ping the autho | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | -| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | +| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | | EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) | | Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.0986) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) | - - +| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -1282,3 +1281,42 @@ pipe = pipe.to("cuda") prompt = "Face of a yellow cat, high resolution, sitting on a park bench" image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` + +### TensorRT Image2Image Stable Diffusion Pipeline + +The TensorRT Pipeline can be used to accelerate the Image2Image Stable Diffusion Inference run. + +NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes. + +```python +import requests +from io import BytesIO +from PIL import Image +import torch +from diffusers import DDIMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline + +# Use the DDIMScheduler scheduler here instead +scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + + +pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", + custom_pipeline="stable_diffusion_tensorrt_img2img", + revision='fp16', + torch_dtype=torch.float16, + scheduler=scheduler,) + +# re-use cached folder to save ONNX models and TensorRT Engines +pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',) + +pipe = pipe.to("cuda") + +url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png" +response = requests.get(url) +input_image = Image.open(BytesIO(response.content)).convert("RGB") + +prompt = "photorealistic new zealand hills" +image = pipe(prompt, image=input_image, strength=0.75,).images[0] +image.save('tensorrt_img2img_new_zealand_hills.png') +``` diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py new file mode 100755 index 000000000000..67c7c2d00fbf --- /dev/null +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -0,0 +1,1055 @@ +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from collections import OrderedDict +from copy import copy +from typing import List, Optional, Union + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import PIL +import tensorrt as trt +import torch +from huggingface_hub import snapshot_download +from onnx import shape_inference +from polygraphy import cuda +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.backend.trt import util as trt_util +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging + + +""" +Installation instructions +python3 -m pip install --upgrade transformers diffusers>=0.16.0 +python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install onnxruntime +""" + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + + +def preprocess_image(image): + """ + image: torch.Tensor + """ + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h)) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).contiguous() + return 2.0 * image - 1.0 + + +class Engine: + def __init__(self, engine_path): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + workspace_size=0, + ): + logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] + if enable_preview: + # Faster dynamic shapes made optional since it increases engine build time. + config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) + if workspace_size > 0: + config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + logger.warning(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict, stream): + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class Optimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +class BaseModel: + def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + onnx_opt_graph = opt.cleanup(return_onnx=True) + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def getOnnxPath(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") + + +def getEnginePath(model_name, engine_dir): + return os.path.join(engine_dir, model_name + ".plan") + + +def build_engines( + models: dict, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_shape=True, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + max_workspace_size=0, +): + built_engines = {} + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + if force_engine_rebuild or not os.path.exists(engine_path): + logger.warning("Building Engines...") + logger.warning("Engine build can take a while to complete") + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + if force_engine_rebuild or not os.path.exists(onnx_path): + logger.warning(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.warning(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + logger.warning(f"Generating optimizing model: {onnx_opt_path}") + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.warning(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + engine = Engine(engine_path) + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + + if force_engine_rebuild or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_preview=enable_preview, + timing_cache=timing_cache, + workspace_size=max_workspace_size, + ) + built_engines[model_name] = engine + + # Load and activate TensorRT engines + for model_name, model_obj in models.items(): + engine = built_engines[model_name] + engine.load() + engine.activate() + + return built_engines + + +def runEngine(engine, feed_dict, stream): + return engine.infer(feed_dict, stream) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(CLIP, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "CLIP" + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt_onnx_graph = opt.cleanup(return_onnx=True) + return opt_onnx_graph + + +def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False): + return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class UNet(BaseModel): + def __init__( + self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 + ): + super(UNet, self).__init__( + model=model, + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.name = "UNet" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False): + return UNet( + model, + fp16=True, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + unet_dim=(9 if inpaint else 4), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAE, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE decoder" + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.vae_encoder = model + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAEEncoder, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE encoder" + + def get_model(self): + vae_encoder = TorchVAEEncoder(self.model) + return vae_encoder + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) + + +def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): + r""" + Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion. + + This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + stages=["clip", "unet", "vae", "vae_encoder"], + image_height: int = 512, + image_width: int = 512, + max_batch_size: int = 16, + # ONNX export parameters + onnx_opset: int = 17, + onnx_dir: str = "onnx", + # TensorRT engine build parameters + engine_dir: str = "engine", + build_preview_features: bool = True, + force_engine_rebuild: bool = False, + timing_cache: str = "timing_cache", + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + + self.vae.forward = self.vae.decode + + self.stages = stages + self.image_height, self.image_width = image_height, image_width + self.inpaint = False + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.timing_cache = timing_cache + self.build_static_batch = False + self.build_dynamic_shape = False + self.build_preview_features = build_preview_features + + self.max_batch_size = max_batch_size + # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. + if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: + self.max_batch_size = 4 + + self.stream = None # loaded in loadResources() + self.models = {} # loaded in __loadModels() + self.engine = {} # loaded in build_engines() + + def __loadModels(self): + # Load pipeline models + self.embedding_dim = self.text_encoder.config.hidden_size + models_args = { + "device": self.torch_device, + "max_batch_size": self.max_batch_size, + "embedding_dim": self.embedding_dim, + "inpaint": self.inpaint, + } + if "clip" in self.stages: + self.models["clip"] = make_CLIP(self.text_encoder, **models_args) + if "unet" in self.stages: + self.models["unet"] = make_UNet(self.unet, **models_args) + if "vae" in self.stages: + self.models["vae"] = make_VAE(self.vae, **models_args) + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): + super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings) + + self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) + self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) + self.timing_cache = os.path.join(self.cached_folder, self.timing_cache) + + # set device + self.torch_device = self._execution_device + logger.warning(f"Running inference on device: {self.torch_device}") + + # load models + self.__loadModels() + + # build engines + self.engine = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + opt_image_height=self.image_height, + opt_image_width=self.image_width, + force_engine_rebuild=self.force_engine_rebuild, + static_batch=self.build_static_batch, + static_shape=not self.build_dynamic_shape, + enable_preview=self.build_preview_features, + timing_cache=self.timing_cache, + ) + + return self + + def __initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device) + return timesteps, t_start + + def __preprocess_images(self, batch_size, images=()): + init_images = [] + for image in images: + image = image.to(self.torch_device).float() + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + return tuple(init_images) + + def __encode_image(self, init_image): + init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[ + "latent" + ] + init_latents = 0.18215 * init_latents + return init_latents + + def __encode_prompt(self, prompt, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + """ + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + text_input_ids_inp = device_view(text_input_ids) + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + "text_embeddings" + ].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + uncond_input_ids_inp = device_view(uncond_input_ids) + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + "text_embeddings" + ] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def __denoise_latent( + self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = device_view(latent_model_input) + timestep_inp = device_view(timestep_float) + embeddings_inp = device_view(text_embeddings) + noise_pred = runEngine( + self.engine["unet"], + {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + self.stream, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __loadResources(self, image_height, image_width, batch_size): + self.stream = cuda.Stream() + + # Allocate buffers for TensorRT engine bindings + for model_name, obj in self.models.items(): + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + + """ + self.generator = generator + self.denoising_steps = num_inference_steps + self.guidance_scale = guidance_scale + + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + + if negative_prompt is not None and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + assert len(prompt) == len(negative_prompt) + + if batch_size > self.max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + # load resources + self.__loadResources(self.image_height, self.image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): + # Initialize timesteps + timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) + + # Pre-process input image + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + init_image = self.__preprocess_images(batch_size, (image,))[0] + + # VAE encode init image + init_latents = self.__encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn( + init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32 + ) + latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) + + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start) + + # VAE decode latent + images = self.__decode_latent(latents) + + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py old mode 100644 new mode 100755 index aa7b5c12313b..b51f3176b958 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -54,8 +54,9 @@ """ Installation instructions -python3 -m pip install --upgrade tensorrt -python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install --upgrade transformers diffusers>=0.16.0 +python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime """ @@ -132,7 +133,7 @@ def build( config_kwargs["tactic_sources"] = [] engine = engine_from_network( - network_from_onnx_path(onnx_path), + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), save_timing_cache=timing_cache, ) @@ -633,6 +634,7 @@ def __init__( onnx_dir: str = "onnx", # TensorRT engine build parameters engine_dir: str = "engine", + build_preview_features: bool = True, force_engine_rebuild: bool = False, timing_cache: str = "timing_cache", ): @@ -652,7 +654,7 @@ def __init__( self.timing_cache = timing_cache self.build_static_batch = False self.build_dynamic_shape = False - self.build_preview_features = False + self.build_preview_features = build_preview_features self.max_batch_size = max_batch_size # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. From 886575ee43c3e7060d74e2feb2018111e0998013 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 May 2023 20:07:21 +0200 Subject: [PATCH 04/33] Refactor controlnet and add img2img and inpaint (#3386) * refactor controlnet and add img2img and inpaint * First draft to get pipelines to work * make style * Fix more * Fix more * More tests * Fix more * Make inpainting work * make style and more tests * Apply suggestions from code review * up * make style * Fix imports * Fix more * Fix more * Improve examples * add test * Make sure import is correctly deprecated * Make sure everything works in compile mode * make sure authorship is correctly attributed --- docs/source/en/_toctree.yml | 4 +- .../{stable_diffusion => }/controlnet.mdx | 61 +- docs/source/en/api/pipelines/overview.mdx | 2 +- docs/source/en/index.mdx | 2 +- src/diffusers/__init__.py | 2 + src/diffusers/pipeline_utils.py | 10 + src/diffusers/pipelines/__init__.py | 8 +- .../pipelines/controlnet/__init__.py | 22 + .../pipelines/controlnet/multicontrolnet.py | 66 + .../controlnet/pipeline_controlnet.py | 1035 ++++++++++++++ .../controlnet/pipeline_controlnet_img2img.py | 1113 +++++++++++++++ .../controlnet/pipeline_controlnet_inpaint.py | 1228 +++++++++++++++++ .../controlnet/pipeline_flax_controlnet.py | 537 +++++++ .../pipeline_semantic_stable_diffusion.py | 2 +- .../pipelines/stable_diffusion/__init__.py | 2 - ...peline_flax_stable_diffusion_controlnet.py | 529 +------ .../pipeline_stable_diffusion_controlnet.py | 1102 +-------------- .../dummy_torch_and_transformers_objects.py | 30 + tests/pipelines/controlnet/__init__.py | 0 .../test_controlnet.py} | 10 +- .../controlnet/test_controlnet_img2img.py | 366 +++++ .../controlnet/test_controlnet_inpaint.py | 379 +++++ .../test_flax_controlnet.py} | 2 +- .../test_stable_diffusion_image_variation.py | 5 +- .../test_stable_diffusion_inpaint.py | 5 +- 25 files changed, 4878 insertions(+), 1644 deletions(-) rename docs/source/en/api/pipelines/{stable_diffusion => }/controlnet.mdx (67%) create mode 100644 src/diffusers/pipelines/controlnet/__init__.py create mode 100644 src/diffusers/pipelines/controlnet/multicontrolnet.py create mode 100644 src/diffusers/pipelines/controlnet/pipeline_controlnet.py create mode 100644 src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py create mode 100644 src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py create mode 100644 src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py create mode 100644 tests/pipelines/controlnet/__init__.py rename tests/pipelines/{stable_diffusion/test_stable_diffusion_controlnet.py => controlnet/test_controlnet.py} (98%) create mode 100644 tests/pipelines/controlnet/test_controlnet_img2img.py create mode 100644 tests/pipelines/controlnet/test_controlnet_inpaint.py rename tests/pipelines/{stable_diffusion/test_stable_diffusion_flax_controlnet.py => controlnet/test_flax_controlnet.py} (98%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 246b467d8b04..52d8988206f1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -148,6 +148,8 @@ title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/controlnet + title: ControlNet - local: api/pipelines/cycle_diffusion title: Cycle Diffusion - local: api/pipelines/dance_diffusion @@ -203,8 +205,6 @@ title: Self-Attention Guidance - local: api/pipelines/stable_diffusion/panorama title: MultiDiffusion Panorama - - local: api/pipelines/stable_diffusion/controlnet - title: Text-to-Image Generation with ControlNet Conditioning - local: api/pipelines/stable_diffusion/model_editing title: Text-to-Image Model Editing - local: api/pipelines/stable_diffusion/diffedit diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/controlnet.mdx similarity index 67% rename from docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx rename to docs/source/en/api/pipelines/controlnet.mdx index fd5c87821c01..f9e4c3c47e3e 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/controlnet.mdx @@ -22,7 +22,7 @@ The abstract of the paper is the following: *We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* -This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️ . +This model was contributed by the community contributor [takuma104](https://huggingface.co/takuma104) ❤️ . Resources: @@ -33,7 +33,9 @@ Resources: | Pipeline | Tasks | Demo |---|---|:---:| -| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [StableDiffusionControlNetImg2ImgPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py) | *Image-to-Image Generation with ControlNet Conditioning* | +| [StableDiffusionControlNetInpaintPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_controlnet_inpaint.py) | *Inpainting Generation with ControlNet Conditioning* | ## Usage example @@ -301,21 +303,22 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h ### ControlNet v1.1 -| Model Name | Control Image Overview| Control Image Example | Generated Image Example | -|---|---|---|---| -|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| -|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)
*Trained with pixel to pixel instruction* | No condition .||| -|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)
Trained with image inpainting | No condition.||| -|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)
Trained with multi-level line segment detection | An image with annotated line segments.||| -|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)
Trained with depth estimation | An image with depth information, usually represented as a grayscale image.||| -|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)
Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.||| -|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)
Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.||| -|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)
Trained with line art generation | An image with line art, usually black lines on a white background.||| -|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
Trained with anime line art generation | An image with anime-style line art.||| -|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.||| -|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)
Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.||| -|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)
Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.||| -|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)
Trained with image shuffling | An image with shuffled patches or regions.||| +| Model Name | Control Image Overview| Condition Image | Control Image Example | Generated Image Example | +|---|---|---|---|---| +|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)
| *Trained with canny edge detection* | A monochrome image with white edges on a black background.||| +|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)
| *Trained with pixel to pixel instruction* | No condition .||| +|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)
| Trained with image inpainting | No condition.||| +|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)
| Trained with multi-level line segment detection | An image with annotated line segments.||| +|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)
| Trained with depth estimation | An image with depth information, usually represented as a grayscale image.||| +|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)
| Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.||| +|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)
| Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.||| +|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)
| Trained with line art generation | An image with line art, usually black lines on a white background.||| +|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
| Trained with anime line art generation | An image with anime-style line art.||| +|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
| Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.||| +|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)
| Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.||| +|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)
| Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.||| +|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)
| Trained with image shuffling | An image with shuffled patches or regions.||| +|[lllyasviel/control_v11f1e_sd15_tile](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile)
| Trained with image tiling | A blurry image or part of an image .||| ## StableDiffusionControlNetPipeline [[autodoc]] StableDiffusionControlNetPipeline @@ -329,6 +332,30 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h - disable_xformers_memory_efficient_attention - load_textual_inversion +## StableDiffusionControlNetImg2ImgPipeline +[[autodoc]] StableDiffusionControlNetImg2ImgPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + - load_textual_inversion + +## StableDiffusionControlNetInpaintPipeline +[[autodoc]] StableDiffusionControlNetInpaintPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + - load_textual_inversion + ## FlaxStableDiffusionControlNetPipeline [[autodoc]] FlaxStableDiffusionControlNetPipeline - all diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 91716784f8fe..2b2f95590016 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -46,7 +46,7 @@ available a colab notebook to directly try them out. |---|---|:---:|:---:| | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | - | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation | -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 46a985ac2f8d..66548663827a 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -53,7 +53,7 @@ The library has three main components: |---|---|:---:| | [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | | [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | +| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a8293ea77fef..0d48a16b6216 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -132,6 +132,8 @@ PaintByExamplePipeline, SemanticStableDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5c0c2337dc04..87709d5f616c 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -17,3 +17,13 @@ # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 +from .utils import deprecate + + +deprecate( + "pipelines_utils", + "0.22.0", + "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.", + standard_warn=False, + stacklevel=3, +) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3cddad4a6b26..9b44f4e5eb14 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -44,6 +44,11 @@ else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .audioldm import AudioLDMPipeline + from .controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, @@ -58,7 +63,6 @@ from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, - StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, @@ -133,8 +137,8 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: + from .controlnet import FlaxStableDiffusionControlNetPipeline from .stable_diffusion import ( - FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py new file mode 100644 index 000000000000..76ab63bdb116 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -0,0 +1,22 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + + +if is_transformers_available() and is_flax_available(): + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py new file mode 100644 index 000000000000..91d40b20124c --- /dev/null +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states, + image, + scale, + class_labels, + timestep_cond, + attention_mask, + cross_attention_kwargs, + guess_mode, + return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py new file mode 100644 index 000000000000..8a2ffbbff171 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -0,0 +1,1035 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + controlnet_latent_model_input = latents + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + controlnet_latent_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py new file mode 100644 index 000000000000..cb5492790353 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -0,0 +1,1113 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> np_image = np.array(image) + + >>> # get canny image + >>> np_image = cv2.Canny(np_image, 100, 200) + >>> np_image = np_image[:, :, None] + >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) + >>> canny_image = Image.fromarray(np_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +def prepare_image(image): + if isinstance(image, torch.Tensor): + # Batch single image + if image.ndim == 3: + image = image.unsqueeze(0) + + image = image.to(dtype=torch.float32) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + return image + + +class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + control_image: Union[ + torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image] + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # 4. Prepare image, and controlnet_conditioning_image + image = prepare_image(image) + + # 5. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + controlnet_latent_model_input = latents + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + controlnet_latent_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py new file mode 100644 index 000000000000..a146a1cc2908 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -0,0 +1,1228 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).resize((512, 512)) + >>> mask_image = load_image(mask_url).resize((512, 512)) + + >>> image = np.array(init_image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion inpainting + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "spiderman", + ... num_inference_steps=30, + ... generator=generator, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image] + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + controlnet_latent_model_input = latents + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + controlnet_latent_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py new file mode 100644 index 000000000000..6003fc96b0ad --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -0,0 +1,537 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ..stable_diffusion import FlaxStableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def image_grid(imgs, rows, cols): + ... w, h = imgs[0].size + ... grid = Image.new("RGB", size=(cols * w, rows * h)) + ... for i, img in enumerate(imgs): + ... grid.paste(img, box=(i % cols * w, i // cols * h)) + ... return grid + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the unet during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_text_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing the ControlNet input condition. ControlNet use this input condition to generate + guidance to Unet. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index e3fe20e196d8..911a5018de18 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -8,10 +8,10 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline from . import SemanticStableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index b89dde319cb3..f39ae67a9aff 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -45,7 +45,6 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline - from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy @@ -130,7 +129,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline - from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 7035242a0cda..bec2424ece4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -12,526 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from functools import partial -from typing import Dict, List, Optional, Union +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate -from flax.training.common_utils import shard -from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from ...utils import deprecate +from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401 -from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel -from ...schedulers import ( - FlaxDDIMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, -) -from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from . import FlaxStableDiffusionPipelineOutput -from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Set to True to use python for loop instead of jax.fori_loop for easier debugging -DEBUG = False - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import jax - >>> import numpy as np - >>> import jax.numpy as jnp - >>> from flax.jax_utils import replicate - >>> from flax.training.common_utils import shard - >>> from diffusers.utils import load_image - >>> from PIL import Image - >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel - - - >>> def image_grid(imgs, rows, cols): - ... w, h = imgs[0].size - ... grid = Image.new("RGB", size=(cols * w, rows * h)) - ... for i, img in enumerate(imgs): - ... grid.paste(img, box=(i % cols * w, i // cols * h)) - ... return grid - - - >>> def create_key(seed=0): - ... return jax.random.PRNGKey(seed) - - - >>> rng = create_key(0) - - >>> # get canny image - >>> canny_image = load_image( - ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" - ... ) - - >>> prompts = "best quality, extremely detailed" - >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" - - >>> # load control net and stable diffusion v1-5 - >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 - ... ) - >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 - ... ) - >>> params["controlnet"] = controlnet_params - - >>> num_samples = jax.device_count() - >>> rng = jax.random.split(rng, jax.device_count()) - - >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) - >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) - >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) - - >>> p_params = replicate(params) - >>> prompt_ids = shard(prompt_ids) - >>> negative_prompt_ids = shard(negative_prompt_ids) - >>> processed_image = shard(processed_image) - - >>> output = pipe( - ... prompt_ids=prompt_ids, - ... image=processed_image, - ... params=p_params, - ... prng_seed=rng, - ... num_inference_steps=50, - ... neg_prompt_ids=negative_prompt_ids, - ... jit=True, - ... ).images - - >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) - >>> output_images = image_grid(output_images, num_samples // 4, 4) - >>> output_images.save("generated_image.png") - ``` -""" - - -class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. - - This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`FlaxAutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`FlaxCLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`FlaxControlNetModel`]: - Provides additional conditioning to the unet during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverMultistepScheduler`]. - safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: FlaxAutoencoderKL, - text_encoder: FlaxCLIPTextModel, - tokenizer: CLIPTokenizer, - unet: FlaxUNet2DConditionModel, - controlnet: FlaxControlNetModel, - scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler - ], - safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - dtype: jnp.dtype = jnp.float32, - ): - super().__init__() - self.dtype = dtype - - if safety_checker is None: - logger.warn( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def prepare_text_inputs(self, prompt: Union[str, List[str]]): - if not isinstance(prompt, (str, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - - return text_input.input_ids - - def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): - if not isinstance(image, (Image.Image, list)): - raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - - if isinstance(image, Image.Image): - image = [image] - - processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) - - return processed_images - - def _get_has_nsfw_concepts(self, features, params): - has_nsfw_concepts = self.safety_checker(features, params) - return has_nsfw_concepts - - def _run_safety_checker(self, images, safety_model_params, jit=False): - # safety_model_params should already be replicated when jit is True - pil_images = [Image.fromarray(image) for image in images] - features = self.feature_extractor(pil_images, return_tensors="np").pixel_values - - if jit: - features = shard(features) - has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) - has_nsfw_concepts = unshard(has_nsfw_concepts) - safety_model_params = unreplicate(safety_model_params) - else: - has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() - - images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image will be returned" - " instead. Try again with a different prompt and/or seed." - ) - - return images, has_nsfw_concepts - def _generate( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int, - guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, - controlnet_conditioning_scale: float = 1.0, - ): - height, width = image.shape[-2:] - if height % 64 != 0 or width % 64 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") - - # get prompt text embeddings - prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] - - # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` - # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - batch_size = prompt_ids.shape[0] - - max_length = prompt_ids.shape[-1] - - if neg_prompt_ids is None: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ).input_ids - else: - uncond_input = neg_prompt_ids - negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) - - image = jnp.concatenate([image] * 2) - - latents_shape = ( - batch_size, - self.unet.config.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - def loop_body(step, args): - latents, scheduler_state = args - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = jnp.concatenate([latents] * 2) - - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) - - down_block_res_samples, mid_block_res_sample = self.controlnet.apply( - {"params": params["controlnet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - return_dict=False, - ) - - # predict the noise residual - noise_pred = self.unet.apply( - {"params": params["unet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # perform guidance - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, scheduler_state - - scheduler_state = self.scheduler.set_timesteps( - params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape - ) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * params["scheduler"].init_noise_sigma - - if DEBUG: - # run with python for loop - for i in range(num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) - else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) - - # scale and decode the image latents with vae - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - - image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) - return image - - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, - controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, - return_dict: bool = True, - jit: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt_ids (`jnp.array`): - The prompt or prompts to guide the image generation. - image (`jnp.array`): - Array representing the ControlNet input condition. ControlNet use this input condition to generate - guidance to Unet. - params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights - prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - - height, width = image.shape[-2:] - - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. - guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] - - if isinstance(controlnet_conditioning_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. - controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] - - if jit: - images = _p_generate( - self, - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - else: - images = self._generate( - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - - if self.safety_checker is not None: - safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] - - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) - - # block images - if any(has_nsfw_concept): - for i, is_nsfw in enumerate(has_nsfw_concept): - if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) - - images = images.reshape(num_devices, batch_size, height, width, 3) - else: - images = np.asarray(images) - has_nsfw_concept = False - - if not return_dict: - return (images, has_nsfw_concept) - - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. -# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). -@partial( - jax.pmap, - in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), - static_broadcasted_argnums=(0, 5), +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, ) -def _p_generate( - pipe, - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, -): - return pipe._generate( - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - - -@partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_has_nsfw_concepts(pipe, features, params): - return pipe._get_has_nsfw_concepts(features, params) - - -def unshard(x: jnp.ndarray): - # einops.rearrange(x, 'd b ... -> (d b) ...') - num_devices, batch_size = x.shape[:2] - rest = x.shape[2:] - return x.reshape(num_devices * batch_size, *rest) - - -def preprocess(image, dtype): - image = image.convert("RGB") - w, h = image.size - w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = jnp.array(image).astype(dtype) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return image diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 1cef221ea6e1..c7555e2ebad4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -12,1093 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import inspect -import os -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from torch import nn -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - -from ...image_processor import VaeImageProcessor -from ...loaders import TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel -from ...models.controlnet import ControlNetOutput -from ...models.modeling_utils import ModelMixin -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works +from ...utils import deprecate +from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401 +from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401 + + +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, ) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ... ) - >>> image = np.array(image) - - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() - - >>> pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.manual_seed(0) - >>> image = pipe( - ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image - ... ).images[0] - ``` -""" - - -class MultiControlNetModel(ModelMixin): - r""" - Multiple `ControlNetModel` wrapper class for Multi-ControlNet - - This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be - compatible with `ControlNetModel`. - - Args: - controlnets (`List[ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `ControlNetModel` as a list. - """ - - def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - down_samples, mid_sample = controlnet( - sample, - timestep, - encoder_hidden_states, - image, - scale, - class_labels, - timestep_cond, - attention_mask, - cross_attention_kwargs, - guess_mode, - return_dict, - ) - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) - ] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - -class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - In addition the pipeline inherits the following loading methods: - - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets - as a list, the outputs from each ControlNet are added together to create one combined additional - conditioning. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - if isinstance(controlnet, (list, tuple)): - controlnet = MultiControlNetModel(controlnet) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - # the safety checker can offload the vae again - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # control net hook has be manually offloaded as it alternates with unet - cpu_offload_with_hook(self.controlnet, device) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - "For multiple controlnets: `image` must have the same length as the number of controlnets." - ) - - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - assert False - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) - - if image_is_pil: - image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def _default_height_width(self, height, width, image): - # NOTE: It is possible that a list of images have different - # dimensions for each image, so just checking the first image - # is not _exactly_ correct, but it is simple. - while isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[2] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[3] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - - # override DiffusionPipeline - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - safe_serialization: bool = False, - variant: Optional[str] = None, - ): - if isinstance(self.controlnet, ControlNetModel): - super().save_pretrained(save_directory, safe_serialization, variant) - else: - raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, - `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can - also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If - height and/or width are passed, `image` is resized according to them. If multiple ControlNets are - specified in init, images must be passed as a list such that each element of the list can be correctly - batched for input to a single controlnet. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. If multiple ControlNets are specified in init, you can set the - corresponding scale as a list. - guess_mode (`bool`, *optional*, defaults to `False`): - In this mode, the ControlNet encoder will try best to recognize the content of the input image even if - you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, image) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - image, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - controlnet_conditioning_scale, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) - - global_pool_conditions = ( - self.controlnet.config.global_pool_conditions - if isinstance(self.controlnet, ControlNetModel) - else self.controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare image - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - images = [] - - for image_ in image: - image_ = self.prepare_image( - image=image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - images.append(image_) - - image = images - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. - controlnet_latent_model_input = latents - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - controlnet_latent_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - down_block_res_samples, mid_block_res_sample = self.controlnet( - controlnet_latent_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index f3708107e82a..4c6c595c41d8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -212,6 +212,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/controlnet/__init__.py b/tests/pipelines/controlnet/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py similarity index 98% rename from tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py rename to tests/pipelines/controlnet/test_controlnet.py index bd1470f5ebd1..0453bb38e1ee 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -34,7 +34,10 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -42,7 +45,7 @@ torch.use_deterministic_algorithms(True) -class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -155,6 +158,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) @@ -307,7 +311,7 @@ def test_save_load_optional_components(self): @slow @require_torch_gpu -class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase): +class ControlNetPipelineSlowTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py new file mode 100644 index 000000000000..b83a8af2778b --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetImg2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel +from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import require_torch_gpu + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False +torch.use_deterministic_algorithms(True) + + +class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + control_image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + image = floats_tensor(control_image.shape, rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + + control_image = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + + image = floats_tensor(control_image[0].shape, rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_save_pretrained_raise_not_implemented_exception(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdir: + try: + # save_pretrained is not implemented for Multi-ControlNet + pipe.save_pretrained(tmpdir) + except NotImplementedError: + pass + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_float16(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_local(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_optional_components(self): + ... + + +@slow +@require_torch_gpu +class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "evil space-punk bird" + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + output = pipe( + prompt, + image, + control_image=control_image, + generator=generator, + output_type="np", + num_inference_steps=50, + strength=0.6, + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/img2img.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py new file mode 100644 index 000000000000..786b0e608ef0 --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -0,0 +1,379 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily based on: + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetInpaintPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel +from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import require_torch_gpu + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False +torch.use_deterministic_algorithms(True) + + +class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + control_image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + init_image = init_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64)) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + + control_image = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + init_image = init_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64)) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_save_pretrained_raise_not_implemented_exception(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdir: + try: + # save_pretrained is not implemented for Multi-ControlNet + pipe.save_pretrained(tmpdir) + except NotImplementedError: + pass + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_float16(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_local(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_optional_components(self): + ... + + +@slow +@require_torch_gpu +class ControlNetInpaintPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_inpaint/input_bench_mask.png" + ).resize((512, 512)) + + prompt = "pitch black hole" + + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + + output = pipe( + prompt, + image=image, + mask_image=mask_image, + control_image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/inpaint.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py similarity index 98% rename from tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py rename to tests/pipelines/controlnet/test_flax_controlnet.py index 268c01320177..4ad75b407acc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py +++ b/tests/pipelines/controlnet/test_flax_controlnet.py @@ -30,7 +30,7 @@ @slow @require_flax -class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase): +class FlaxControlNetPipelineIntegrationTests(unittest.TestCase): def tearDown(self): # clean up the VRAM after each test super().tearDown() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 8c27a568d24d..0ce55ae78ae0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -46,9 +46,8 @@ class StableDiffusionImageVariationPipelineFastTests( pipeline_class = StableDiffusionImageVariationPipeline params = IMAGE_VARIATION_PARAMS batch_params = IMAGE_VARIATION_BATCH_PARAMS - image_params = frozenset( - [] - ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index cdf138c4e178..a215e4da6697 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -47,9 +47,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - image_params = frozenset( - [] - ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) From 17f9aed79cd073f4475bd3af1c6f34b681839685 Mon Sep 17 00:00:00 2001 From: clarencechen Date: Tue, 16 May 2023 11:26:53 -0700 Subject: [PATCH 05/33] [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335) * Add DPM-Solver Multistep Inverse Scheduler * Add draft tests for DiffEdit * Add inverse sde-dpmsolver steps to tune image diversity from inverted latents * Fix tests --------- Co-authored-by: Patrick von Platen --- docs/source/en/_toctree.yml | 2 + .../multistep_dpm_solver_inverse.mdx | 22 + src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../scheduling_dpmsolver_multistep_inverse.py | 701 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../test_stable_diffusion_diffedit.py | 77 ++ 7 files changed, 819 insertions(+) create mode 100644 docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx create mode 100644 src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 52d8988206f1..645cbb04c1d0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -252,6 +252,8 @@ title: Euler scheduler - local: api/schedulers/heun title: Heun Scheduler + - local: api/schedulers/multistep_dpm_solver_inverse + title: Inverse Multistep DPM-Solver - local: api/schedulers/ipndm title: IPNDM - local: api/schedulers/lms_discrete diff --git a/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx b/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx new file mode 100644 index 000000000000..1b3348a5a3ea --- /dev/null +++ b/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx @@ -0,0 +1,22 @@ + + +# Inverse Multistep DPM-Solver (DPMSolverMultistepInverse) + +## Overview + +This scheduler is the inverted scheduler of [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://arxiv.org/abs/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models +](https://arxiv.org/abs/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. +The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf) and the ad-hoc notebook implementation for DiffEdit latent inversion [here](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/diffedit.ipynb). + +## DPMSolverMultistepInverseScheduler +[[autodoc]] DPMSolverMultistepInverseScheduler diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0d48a16b6216..9b3f8adad376 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ DDIMScheduler, DDPMScheduler, DEISMultistepScheduler, + DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c4b62c722257..05414e32fc9e 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -33,6 +33,7 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py new file mode 100644 index 000000000000..b424ebbff262 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -0,0 +1,701 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): + """ + DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`]. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self.use_karras_sigmas = use_karras_sigmas + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = timesteps.copy().astype(np.int64) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = ( + self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + ) + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 014e193aa32a..e07b7cb27da7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -450,6 +450,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index d32f4d665f55..c9da7b06893f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -27,6 +27,8 @@ AutoencoderKL, DDIMInverseScheduler, DDIMScheduler, + DPMSolverMultistepInverseScheduler, + DPMSolverMultistepScheduler, StableDiffusionDiffEditPipeline, UNet2DConditionModel, ) @@ -256,6 +258,30 @@ def test_inversion(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=5e-3) + def test_inversion_dpm(self): + device = "cpu" + + components = self.get_dummy_components() + + scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"} + components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args) + components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args) + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inversion_inputs(device) + image = pipe.invert(**inputs).images + image_slice = image[0, -1, -3:, -3:] + + self.assertEqual(image.shape, (2, 32, 32, 3)) + expected_slice = np.array( + [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + @require_torch_gpu @slow @@ -320,3 +346,54 @@ def test_stable_diffusion_diffedit_full(self): / 255 ) assert np.abs((expected_image - image).max()) < 5e-1 + + def test_stable_diffusion_diffedit_dpm(self): + generator = torch.manual_seed(0) + + pipe = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + source_prompt = "a bowl of fruit" + target_prompt = "a bowl of pears" + + mask_image = pipe.generate_mask( + image=self.raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + generator=generator, + ) + + inv_latents = pipe.invert( + prompt=source_prompt, + image=self.raw_image, + inpaint_strength=0.7, + generator=generator, + num_inference_steps=25, + ).latents + + image = pipe( + prompt=target_prompt, + mask_image=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, + inpaint_strength=0.7, + num_inference_steps=25, + output_type="numpy", + ).images[0] + + expected_image = ( + np.array( + load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/diffedit/pears.png" + ).resize((768, 768)) + ) + / 255 + ) + assert np.abs((expected_image - image).max()) < 5e-1 From 754fac82d2e0237edff20c4eee3f0f2ea4ab91a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Tue, 16 May 2023 20:33:34 +0200 Subject: [PATCH 06/33] [Docs] Fix incomplete docstring for resnet.py (#3438) Fix incomplete docstrings for resnet.py --- src/diffusers/models/resnet.py | 86 ++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d9d539959c09..debe120e8ead 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -24,14 +24,17 @@ class Upsample1D(nn.Module): - """ - An upsampling layer with an optional convolution. + """A 1D upsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -62,14 +65,17 @@ def forward(self, x): class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. + """A 1D downsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -93,14 +99,17 @@ def forward(self, x): class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. + """A 2D upsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -162,14 +171,17 @@ def forward(self, hidden_states, output_size=None): class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. + """A 2D downsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -209,6 +221,19 @@ def forward(self, hidden_states): class FirUpsample2D(nn.Module): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_channels = out_channels if out_channels else channels @@ -309,6 +334,19 @@ def forward(self, hidden_states): class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_channels = out_channels if out_channels else channels From 92ea5baca2815ecd51f96bedb0fb766b313196f8 Mon Sep 17 00:00:00 2001 From: superlabs-dev <133080491+superlabs-dev@users.noreply.github.com> Date: Wed, 17 May 2023 03:33:47 +0900 Subject: [PATCH 07/33] fix tiled vae blend extent range (#3384) fix tiled vae bleand extent range --- src/diffusers/models/autoencoder_kl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 1a8a204d80ce..a4894e78c43f 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -196,12 +196,14 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=decoded) def blend_v(self, a, b, blend_extent): - for y in range(min(a.shape[2], b.shape[2], blend_extent)): + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b def blend_h(self, a, b, blend_extent): - for x in range(min(a.shape[3], b.shape[3], blend_extent)): + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b From 0392eceba8d42b24fcecc56b2cc1f4582dbefcc4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 16 May 2023 20:35:47 +0200 Subject: [PATCH 08/33] Small update to "Next steps" section (#3443) Small update to "Next steps" section: - PyTorch 2 is recommended. - Updated improvement figures. --- docs/source/en/stable_diffusion.mdx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/stable_diffusion.mdx b/docs/source/en/stable_diffusion.mdx index d02e93033614..64c90c7f6477 100644 --- a/docs/source/en/stable_diffusion.mdx +++ b/docs/source/en/stable_diffusion.mdx @@ -266,6 +266,6 @@ image_grid(images) In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources: -- Enable [xFormers](./optimization/xformers) memory efficient attention mechanism for faster speed and reduced memory consumption. -- Learn how in [PyTorch 2.0](./optimization/torch2.0), [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 2-9% faster inference speed. -- Many optimization techniques for inference are also included in this memory and speed [guide](./optimization/fp16), such as memory offloading. +- Learn how [PyTorch 2.0](./optimization/torch2.0) and [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 5 - 300% faster inference speed. +- If you can't use PyTorch 2, we recommend you install [xFormers](./optimization/xformers). Its memory-efficient attention mechanism works great with PyTorch 1.13.1 for faster speed and reduced memory consumption. +- Other optimization techniques, such as model offloading, are covered in [this guide](./optimization/fp16). From 6070b32fcfd13fdf81547c91f9333fb117bc3982 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 17 May 2023 07:51:07 +0530 Subject: [PATCH 09/33] Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298) * Update pipeline_if_superresolution.py Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape * IFSuperResolutionPipeline: allow the user to override the height and width through the arguments * update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions --------- Co-authored-by: Patrick von Platen --- .../deepfloyd_if/pipeline_if_superresolution.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index 1ba8f888a8e3..2fe8e6a9d5d5 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -695,6 +695,8 @@ def preprocess_image(self, image, num_images_per_prompt, device): def __call__( self, prompt: Union[str, List[str]] = None, + height: int = None, + width: int = None, image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, num_inference_steps: int = 50, timesteps: List[int] = None, @@ -720,6 +722,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): The image to be upscaled. num_inference_steps (`int`, *optional*, defaults to 50): @@ -806,8 +812,8 @@ def __call__( # 2. Define call parameters - height = self.unet.config.sample_size - width = self.unet.config.sample_size + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size device = self._execution_device From c09c4f3ab7ab7d46727949e003facb391e1e8b8d Mon Sep 17 00:00:00 2001 From: Rupert Menneer <71332436+rupertmenneer@users.noreply.github.com> Date: Wed, 17 May 2023 03:05:16 -0700 Subject: [PATCH 10/33] Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424) * Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the cpy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen Co-authored-by: William Berman --- .../controlnet/pipeline_controlnet_inpaint.py | 47 +++++++- .../pipeline_stable_diffusion_inpaint.py | 93 +++++++++++++-- .../test_stable_diffusion_inpaint.py | 108 ++++++++++++++---- 3 files changed, 211 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index a146a1cc2908..27475dc5ef8b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -99,7 +99,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image -def prepare_mask_and_masked_image(image, mask, height, width): +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -209,6 +209,10 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + return mask, masked_image @@ -795,7 +799,20 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -803,13 +820,37 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if is_strength_max: + # if strength is 100% then simply initialise the latents to noise + latents = noise + else: + # otherwise initialise latents as init image + noise + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + return latents def _default_height_width(self, height, width, image): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 518a9a3e9781..78ef11587b4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask, height, width): +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -146,6 +146,10 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + return mask, masked_image @@ -552,17 +556,20 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, + strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -600,8 +607,20 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -609,13 +628,37 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if is_strength_max: + # if strength is 100% then simply initialise the latents to noise + latents = noise + else: + # otherwise initialise latents as init image + noise + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + return latents def prepare_mask_latents( @@ -669,6 +712,16 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + @torch.no_grad() def __call__( self, @@ -677,6 +730,7 @@ def __call__( mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -710,6 +764,13 @@ def __call__( The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -802,6 +863,7 @@ def __call__( prompt, height, width, + strength, callback_steps, negative_prompt, prompt_embeds, @@ -833,12 +895,20 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Preprocess mask and image - resizes image and mask w.r.t height and width - mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) - - # 5. set timesteps + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels @@ -851,6 +921,9 @@ def __call__( device, generator, latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, ) # 7. Prepare mask latent variables diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index a215e4da6697..5c5e4c4590dc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -324,6 +324,26 @@ def test_stable_diffusion_inpaint_pil_input_resolution_test(self): # verify that the returned image has the same height and width as the input height and width assert image.shape == (1, inputs["height"], inputs["width"], 3) + def test_stable_diffusion_inpaint_strength_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input strength + inputs["strength"] = 0.75 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, 253:256, 253:256, -1].flatten() + expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) + assert np.abs(expected_slice - image_slice).max() < 3e-3 + @nightly @require_torch_gpu @@ -427,24 +447,30 @@ def test_pil_inputs(self): mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width) + t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) + self.assertTrue(isinstance(t_image, torch.Tensor)) self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_masked.ndim, 4) + self.assertEqual(t_image.ndim, 4) self.assertEqual(t_mask.shape, (1, 1, height, width)) self.assertEqual(t_masked.shape, (1, 3, height, width)) + self.assertEqual(t_image.shape, (1, 3, height, width)) self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32) + self.assertTrue(t_image.dtype == torch.float32) self.assertTrue(t_mask.min() >= 0.0) self.assertTrue(t_mask.max() <= 1.0) self.assertTrue(t_masked.min() >= -1.0) self.assertTrue(t_masked.min() <= 1.0) + self.assertTrue(t_image.min() >= -1.0) + self.assertTrue(t_image.min() >= -1.0) self.assertTrue(t_mask.sum() > 0.0) @@ -467,11 +493,16 @@ def test_np_inputs(self): ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) - t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) + t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image( + im_pil, mask_pil, height, width, return_image=True + ) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) + self.assertTrue((t_image_np == t_image_pil).all()) def test_torch_3D_2D_inputs(self): height, width = 32, 32 @@ -501,13 +532,16 @@ def test_torch_3D_2D_inputs(self): im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_3D_3D_inputs(self): height, width = 32, 32 @@ -538,13 +572,16 @@ def test_torch_3D_3D_inputs(self): im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_2D_inputs(self): height, width = 32, 32 @@ -575,13 +612,16 @@ def test_torch_4D_2D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_3D_inputs(self): height, width = 32, 32 @@ -613,13 +653,16 @@ def test_torch_4D_3D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_4D_inputs(self): height, width = 32, 32 @@ -652,13 +695,16 @@ def test_torch_4D_4D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_3D(self): height, width = 32, 32 @@ -691,15 +737,17 @@ def test_torch_batch_4D_3D(self): im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_4D(self): height, width = 32, 32 @@ -733,15 +781,17 @@ def test_torch_batch_4D_4D(self): im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_shape_mismatch(self): height, width = 32, 32 @@ -757,6 +807,7 @@ def test_shape_mismatch(self): torch.randn(64, 64), height, width, + return_image=True, ) # test batch dim with self.assertRaises(AssertionError): @@ -770,6 +821,7 @@ def test_shape_mismatch(self): torch.randn(4, 64, 64), height, width, + return_image=True, ) # test batch dim with self.assertRaises(AssertionError): @@ -783,6 +835,7 @@ def test_shape_mismatch(self): torch.randn(4, 1, 64, 64), height, width, + return_image=True, ) def test_type_mismatch(self): @@ -803,6 +856,7 @@ def test_type_mismatch(self): ).numpy(), height, width, + return_image=True, ) # test tensors-only with self.assertRaises(TypeError): @@ -819,6 +873,7 @@ def test_type_mismatch(self): ), height, width, + return_image=True, ) def test_channels_first(self): @@ -835,6 +890,7 @@ def test_channels_first(self): ), height, width, + return_image=True, ) def test_tensor_range(self): @@ -855,6 +911,7 @@ def test_tensor_range(self): ), height, width, + return_image=True, ) # test im >= -1 with self.assertRaises(ValueError): @@ -871,6 +928,7 @@ def test_tensor_range(self): ), height, width, + return_image=True, ) # test mask <= 1 with self.assertRaises(ValueError): @@ -887,6 +945,7 @@ def test_tensor_range(self): * 2, height, width, + return_image=True, ) # test mask >= 0 with self.assertRaises(ValueError): @@ -903,4 +962,5 @@ def test_tensor_range(self): * -1, height, width, + return_image=True, ) From 415c616712d82fff64df739aae79ec5fce01f045 Mon Sep 17 00:00:00 2001 From: Vimarsh Chaturvedi Date: Wed, 17 May 2023 12:05:33 +0200 Subject: [PATCH 11/33] [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448) Added bugfix using f strings. --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a4d3dd1f1673..fa71a181f521 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1249,7 +1249,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # allow all patterns from non-model folders # this enables downloading schedulers, tokenizers, ... - allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] From 15f1bab13bf3d9ca956d2398e1f550c840fa2bb1 Mon Sep 17 00:00:00 2001 From: 7eu7d7 <31194890+7eu7d7@users.noreply.github.com> Date: Wed, 17 May 2023 18:06:04 +0800 Subject: [PATCH 12/33] Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404) * gradient checkpointing bug fix * bug fix; changes for reviews * reformat * reformat --------- Co-authored-by: Patrick von Platen --- src/diffusers/models/unet_2d_blocks.py | 173 ++++++++++++++---- src/diffusers/models/vae.py | 46 +++-- .../versatile_diffusion/modeling_text_unet.py | 76 ++++++-- 3 files changed, 230 insertions(+), 65 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 0004f074c563..7b76dd7e37bd 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch import nn +from ..utils import is_torch_version from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel @@ -866,13 +867,27 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -957,7 +972,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1361,7 +1383,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1558,7 +1587,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1653,14 +1689,29 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1874,13 +1925,27 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1960,7 +2025,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2388,7 +2460,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2593,7 +2672,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2714,14 +2800,29 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + attention_mask, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 400c3030af90..6f8514f28d33 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ..utils import BaseOutput, randn_tensor +from ..utils import BaseOutput, is_torch_version, randn_tensor from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -117,11 +117,20 @@ def custom_forward(*inputs): return custom_forward # down - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) else: # down @@ -221,13 +230,26 @@ def custom_forward(*inputs): return custom_forward - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) - sample = sample.to(upscale_dtype) + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + sample = sample.to(upscale_dtype) - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, use_reentrant=False + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) else: # middle sample = self.mid_block(sample) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index f0a210339c46..7aaa0e49e1da 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -18,7 +18,7 @@ from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging +from ...utils import is_torch_version, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1077,7 +1077,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1198,13 +1205,27 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1289,7 +1310,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1412,13 +1440,27 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( From 3ebd2d1f9ec97f8bc0fc5cc8c7313bdf5f0dc1d2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 May 2023 12:20:13 +0200 Subject: [PATCH 13/33] Make dreambooth lora more robust to orig unet (#3462) * Make dreambooth lora more robust to orig unet * up --- examples/dreambooth/train_dreambooth_lora.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 16adfe4b83fc..bfbf3603e8d0 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -31,7 +31,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, model_info, upload_folder +from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torch.utils.data import Dataset @@ -589,16 +589,6 @@ def __getitem__(self, index): return example -def model_has_vae(args): - config_file_name = os.path.join("vae", AutoencoderKL.config_name) - if os.path.isdir(args.pretrained_model_name_or_path): - config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) - return os.path.isfile(config_file_name) - else: - files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings - return any(file.rfilename == config_file_name for file in files_in_repo) - - def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): if tokenizer_max_length is not None: max_length = tokenizer_max_length @@ -753,11 +743,13 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - if model_has_vae(args): + try: vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision ) - else: + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here vae = None unet = UNet2DConditionModel.from_pretrained( From bd78f63a54e439a46f162f191618e3ba554aeef6 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 17 May 2023 15:54:59 +0530 Subject: [PATCH 14/33] Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463) Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size. --- src/diffusers/models/attention_processor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f88400da0333..a489814c4787 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -344,11 +344,14 @@ def get_attention_scores(self, query, key, attention_mask=None): beta=beta, alpha=self.scale, ) + del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + attention_probs = attention_probs.to(dtype) return attention_probs From 2faf91dbdeb51ad41e8a398d16818932374cde0c Mon Sep 17 00:00:00 2001 From: wfng92 <43742196+wfng92@users.noreply.github.com> Date: Wed, 17 May 2023 19:07:45 +0800 Subject: [PATCH 15/33] Add min snr to text2img lora training script (#3459) add min snr to text2img lora training script --- .../text_to_image/train_text_to_image_lora.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c2a4e1aacdb7..806637f04c53 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -239,6 +239,13 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -472,6 +479,30 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + lora_layers = AttnProcsLayers(unet.attn_processors) # Enable TF32 for faster training on Ampere GPUs, @@ -727,7 +758,23 @@ def collate_fn(examples): # Predict the noise residual and compute loss model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() From 88295f92d963f414cc7adf93f30c694a4d100dd2 Mon Sep 17 00:00:00 2001 From: Glaceon-Hyy Date: Wed, 17 May 2023 19:28:19 +0800 Subject: [PATCH 16/33] Add inpaint lora scale support (#3460) * add inpaint lora scale support * add inpaint lora scale test --------- Co-authored-by: yueyang.hyy --- .../pipeline_stable_diffusion_inpaint.py | 18 +++++++--- .../test_stable_diffusion_inpaint.py | 35 +++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 78ef11587b4d..f09db016d956 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -14,7 +14,7 @@ import inspect import warnings -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -744,6 +744,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -815,7 +816,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py @@ -966,9 +970,13 @@ def __call__( latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[ - 0 - ] + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 5c5e4c4590dc..5c2d9d7c44f7 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -35,6 +35,7 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu +from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -155,6 +156,40 @@ def test_stable_diffusion_inpaint_image_tensor(self): assert out_pil.shape == (1, 64, 64, 3) assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2 + def test_stable_diffusion_inpaint_lora(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward 1 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + lora_attn_procs = create_lora_layers(sd_pipe.unet) + sd_pipe.unet.set_attn_processor(lora_attn_procs) + sd_pipe = sd_pipe.to(torch_device) + + # forward 2 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) From 2858d7e15eaf445ec37fc77b204a85f84affbeef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 May 2023 14:26:53 +0200 Subject: [PATCH 17/33] [From ckpt] Fix from_ckpt (#3466) * Correct from_ckpt * make style --- src/diffusers/loaders.py | 2 +- .../stable_diffusion/convert_from_ckpt.py | 22 +++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a1f0d8ec2a52..e50bc31a5c63 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1326,7 +1326,7 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs): file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] from_safetensors = file_extension == "safetensors" - if from_safetensors and use_safetensors is True: + if from_safetensors and use_safetensors is False: raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") # TODO: For now we only support stable diffusion diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 5961636dd197..42e8ae7cafd2 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -204,8 +204,12 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] From c9f939bf9885de32ada828809410b4a6c1d9ff2a Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 17 May 2023 10:42:20 -0700 Subject: [PATCH 18/33] Update full dreambooth script to work with IF (#3425) --- examples/dreambooth/train_dreambooth.py | 306 ++++++++++++++++++++---- examples/test_examples.py | 26 ++ src/diffusers/models/unet_2d_blocks.py | 69 ++++-- 3 files changed, 344 insertions(+), 57 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 5d2107f024d1..efcfb39ab4c4 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import gc import hashlib import itertools import logging @@ -30,7 +31,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder +from huggingface_hub import create_repo, model_info, upload_folder from packaging import version from PIL import Image from torch.utils.data import Dataset @@ -93,31 +94,61 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode f.write(yaml + model_card) -def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): +def log_validation( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds +): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if vae is not None: + pipeline_args["vae"] = vae + # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, unet=accelerator.unwrap_model(unet), - vae=vae, revision=args.revision, torch_dtype=weight_dtype, + **pipeline_args, ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): with torch.autocast("cuda"): - image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] images.append(image) for tracker in accelerator.trackers: @@ -155,6 +186,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") @@ -459,6 +494,27 @@ def parse_args(input_args=None): " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." ), ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" + ) if input_args is not None: args = parser.parse_args(input_args) @@ -481,6 +537,9 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + return args @@ -500,10 +559,16 @@ def __init__( class_num=None, size=512, center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -545,40 +610,52 @@ def __getitem__(self, index): if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask return example def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -588,6 +665,10 @@ def collate_fn(examples, with_prior_preservation=False): "input_ids": input_ids, "pixel_values": pixel_values, } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + return batch @@ -608,6 +689,50 @@ def __getitem__(self, index): return example +def model_has_vae(args): + config_file_name = os.path.join("vae", AutoencoderKL.config_name) + if os.path.isdir(args.pretrained_model_name_or_path): + config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) + return os.path.isfile(config_file_name) + else: + files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings + return any(file.rfilename == config_file_name for file in files_in_repo) + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -727,7 +852,14 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + if model_has_vae(args): + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + else: + vae = None + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) @@ -761,7 +893,9 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - vae.requires_grad_(False) + if vae is not None: + vae.requires_grad_(False) + if not args.train_text_encoder: text_encoder.requires_grad_(False) @@ -835,6 +969,44 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -845,6 +1017,9 @@ def load_model_hook(models, input_dir): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, ) train_dataloader = torch.utils.data.DataLoader( @@ -890,8 +1065,10 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) - if not args.train_text_encoder: + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + + if not args.train_text_encoder and text_encoder is not None: text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -961,37 +1138,55 @@ def load_model_hook(models, input_dir): continue with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - # Sample noise that we'll add to the latents + if vae is not None: + # Convert images to latent space + model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the model input if args.offset_noise: - noise = torch.randn_like(latents) + 0.1 * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=latents.device + noise = torch.randn_like(model_input) + 0.1 * torch.randn( + model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device ) else: - noise = torch.randn_like(latents) - bsz = latents.shape[0] + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep + # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample + + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -1037,7 +1232,16 @@ def load_model_hook(models, input_dir): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, + validation_prompt_encoder_hidden_states, + validation_prompt_negative_prompt_embeds, ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -1050,12 +1254,34 @@ def load_model_hook(models, input_dir): # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if args.skip_save_text_encoder: + pipeline_args["text_encoder"] = None + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), revision=args.revision, + **pipeline_args, ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.save_pretrained(args.output_dir) if args.push_to_hub: diff --git a/examples/test_examples.py b/examples/test_examples.py index d9e7de717f47..59c96f44fe93 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -147,6 +147,32 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_if(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --pre_compute_text_embeddings + --tokenizer_max_length=77 + --text_encoder_use_attention_mask + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_checkpointing(self): instance_prompt = "photo" pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 7b76dd7e37bd..75d9eb3e03df 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1507,16 +1507,33 @@ def forward( cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): - # resnet - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) output_states = output_states + (hidden_states,) @@ -2593,15 +2610,33 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) if self.upsamplers is not None: for upsampler in self.upsamplers: From 7200985eab7126801fffcf8251fd149c1cf1f291 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 17 May 2023 11:56:10 -0700 Subject: [PATCH 19/33] Add IF dreambooth docs (#3470) --- examples/dreambooth/README.md | 64 +++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 75d705f89e02..086100bd4a36 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -531,3 +531,67 @@ More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_ ### Experimental results You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects. + +## IF + +You can use the lora and full dreambooth scripts to also train the text to image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). A few alternative cli flags are needed due to the model size, the expected input resolution, and the text encoder conventions. + +### LoRA Dreambooth +This training configuration requires ~28 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_lora" + +accelerate launch train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=64 \ # The input resolution of the IF unet is 64x64 + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --scale_lr \ + --max_train_steps=1200 \ + --validation_prompt="a sks dog" \ + --validation_epochs=25 \ + --checkpointing_steps=100 \ + --pre_compute_text_embeddings \ # Pre compute text embeddings to that T5 doesn't have to be kept in memory + --tokenizer_max_length=77 \ # IF expects an override of the max token length + --text_encoder_use_attention_mask # IF expects attention mask for text embeddings +``` + +### Full Dreambooth +Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. +Using 8bit adam and the rest of the following config, the model can be trained in ~48 GB VRAM. + +For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" + +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_if" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=64 \ # The input resolution of the IF unet is 64x64 + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-7 \ + --max_train_steps=150 \ + --validation_prompt "a photo of sks dog" \ + --validation_steps 25 \ + --text_encoder_use_attention_mask \ # IF expects attention mask for text embeddings + --tokenizer_max_length 77 \ # IF expects an override of the max token length + --pre_compute_text_embeddings \ # Pre compute text embeddings to that T5 doesn't have to be kept in memory + --use_8bit_adam \ # + --set_grads_to_none \ + --skip_save_text_encoder # do not save the full T5 text encoder with the model +``` From 49b7ccfb965ce77046477f292b8e9f9777bea0e9 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 18 May 2023 10:14:29 -0700 Subject: [PATCH 20/33] parameterize pass single args through tuple (#3477) --- tests/models/test_models_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index fd4cf0114f51..9a3e49cdfbc0 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -321,7 +321,7 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice): assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) - @parameterized.expand([13, 16, 27]) + @parameterized.expand([(13,), (16,), (27,)]) @require_torch_gpu @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): @@ -339,7 +339,7 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): assert torch_all_close(sample, sample_2, atol=1e-1) - @parameterized.expand([13, 16, 37]) + @parameterized.expand([(13,), (16,), (37,)]) @require_torch_gpu @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0(self, seed): From 8917769499632c5539f81e9bae9e923825e5be69 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 18 May 2023 10:24:49 -0700 Subject: [PATCH 21/33] attend and excite tests disable determinism on the class level (#3478) --- ...test_stable_diffusion_attend_and_excite.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 898d5741043f..6cec2cce752d 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -34,7 +34,6 @@ torch.backends.cuda.matmul.allow_tf32 = False -torch.use_deterministic_algorithms(False) @skip_mps @@ -47,6 +46,19 @@ class StableDiffusionAttendAndExcitePipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + # Attend and excite requires being able to run a backward pass at + # inference time. There's no deterministic backward operator for pad + + @classmethod + def setUpClass(cls): + super().setUpClass() + torch.use_deterministic_algorithms(False) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + torch.use_deterministic_algorithms(True) + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -171,6 +183,19 @@ def test_save_load_optional_components(self): @require_torch_gpu @slow class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase): + # Attend and excite requires being able to run a backward pass at + # inference time. There's no deterministic backward operator for pad + + @classmethod + def setUpClass(cls): + super().setUpClass() + torch.use_deterministic_algorithms(False) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + torch.use_deterministic_algorithms(True) + def tearDown(self): super().tearDown() gc.collect() From 8d646f229440999f8c20bf8cbaf016dc4b35441d Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 18 May 2023 19:10:14 -0700 Subject: [PATCH 22/33] dreambooth docs torch.compile note (#3471) * dreambooth docs torch.compile note * Update examples/dreambooth/README.md Co-authored-by: Sayak Paul * Update examples/dreambooth/README.md Co-authored-by: Pedro Cuenca --------- Co-authored-by: Sayak Paul Co-authored-by: Pedro Cuenca --- examples/dreambooth/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 086100bd4a36..83073210ac04 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -43,6 +43,8 @@ from accelerate.utils import write_basic_config write_basic_config() ``` +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. + ### Dog toy example Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. From e343443565d9dbbba026f563c35f0d4a0515a8d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 19 May 2023 07:47:28 +0530 Subject: [PATCH 23/33] add: if entry in the dreambooth training docs. (#3472) --- docs/source/en/training/dreambooth.mdx | 64 ++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index 38a3adf9c4f1..de93772abedd 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -496,3 +496,67 @@ image.save("dog-bucket.png") ``` You may also run inference from any of the [saved training checkpoints](#inference-from-a-saved-checkpoint). + +## IF + +You can use the lora and full dreambooth scripts to also train the text to image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). A few alternative cli flags are needed due to the model size, the expected input resolution, and the text encoder conventions. + +### LoRA Dreambooth +This training configuration requires ~28 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_lora" + +accelerate launch train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=64 \ # The input resolution of the IF unet is 64x64 + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --scale_lr \ + --max_train_steps=1200 \ + --validation_prompt="a sks dog" \ + --validation_epochs=25 \ + --checkpointing_steps=100 \ + --pre_compute_text_embeddings \ # Pre compute text embeddings to that T5 doesn't have to be kept in memory + --tokenizer_max_length=77 \ # IF expects an override of the max token length + --text_encoder_use_attention_mask # IF expects attention mask for text embeddings +``` + +### Full Dreambooth +Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. +Using 8bit adam and the rest of the following config, the model can be trained in ~48 GB VRAM. + +For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" + +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_if" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=64 \ # The input resolution of the IF unet is 64x64 + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-7 \ + --max_train_steps=150 \ + --validation_prompt "a photo of sks dog" \ + --validation_steps 25 \ + --text_encoder_use_attention_mask \ # IF expects attention mask for text embeddings + --tokenizer_max_length 77 \ # IF expects an override of the max token length + --pre_compute_text_embeddings \ # Pre compute text embeddings to that T5 doesn't have to be kept in memory + --use_8bit_adam \ # + --set_grads_to_none \ + --skip_save_text_encoder # do not save the full T5 text encoder with the model +``` \ No newline at end of file From 00c76f6ff19a9667594597c37b4e3da15e9a56db Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Fri, 19 May 2023 09:47:27 -0700 Subject: [PATCH 24/33] [docs] Textual inversion inference (#3473) * add textual inversion inference to docs * add to toctree --------- Co-authored-by: Sayak Paul --- docs/source/en/_toctree.yml | 2 + .../textual_inversion_inference.mdx | 80 +++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 docs/source/en/using-diffusers/textual_inversion_inference.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 645cbb04c1d0..926a3ea716e8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -44,6 +44,8 @@ title: Text-guided image-inpainting - local: using-diffusers/depth2img title: Text-guided depth-to-image + - local: using-diffusers/textual_inversion_inference + title: Textual inversion - local: using-diffusers/reusing_seeds title: Improve image quality with deterministic generation - local: using-diffusers/reproducibility diff --git a/docs/source/en/using-diffusers/textual_inversion_inference.mdx b/docs/source/en/using-diffusers/textual_inversion_inference.mdx new file mode 100644 index 000000000000..9eca3e7e465c --- /dev/null +++ b/docs/source/en/using-diffusers/textual_inversion_inference.mdx @@ -0,0 +1,80 @@ +# Textual inversion + +[[open-in-colab]] + +The [`StableDiffusionPipeline`] supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer). + +This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](./training/text_inversion) training guide. + +Login to your Hugging Face account: + +```py +from huggingface_hub import notebook_login + +notebook_login() +``` + +Import the necessary libraries, and create a helper function to visualize the generated images: + +```py +import os +import torch + +import PIL +from PIL import Image + +from diffusers import StableDiffusionPipeline +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid +``` + +Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer): + +```py +pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" +repo_id_embeds = "sd-concepts-library/cat-toy" +``` + +Now you can load a pipeline, and pass the pre-learned concept to it: + +```py +pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to("cuda") + +pipeline.load_textual_inversion(repo_id_embeds) +``` + +Create a prompt with the pre-learned concept by using the special placeholder token ``, and choose the number of samples and rows of images you'd like to generate: + +```py +prompt = "a grafitti in a favela wall with a on it" + +num_samples = 2 +num_rows = 2 +``` + +Then run the pipeline (feel free to adjust the parameters like `num_inference_steps` and `guidance_scale` to see how they affect image quality), save the generated images and visualize them with the helper function you created at the beginning: + +```py +all_images = [] +for _ in range(num_rows): + images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images + all_images.extend(images) + +grid = image_grid(all_images, num_samples, num_rows) +grid +``` + +
+ +
From e589bdb956c9be33fc73e1d4614d8d1c1ad95544 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Fri, 19 May 2023 10:07:33 -0700 Subject: [PATCH 25/33] [docs] Distributed inference (#3376) * distributed inference * move to inference section * apply feedback * update with split_between_processes * apply feedback --- docs/source/en/_toctree.yml | 2 + .../en/training/distributed_inference.mdx | 91 +++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 docs/source/en/training/distributed_inference.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 926a3ea716e8..aa2d907da4bd 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -46,6 +46,8 @@ title: Text-guided depth-to-image - local: using-diffusers/textual_inversion_inference title: Textual inversion + - local: training/distributed_inference + title: Distributed inference with multiple GPUs - local: using-diffusers/reusing_seeds title: Improve image quality with deterministic generation - local: using-diffusers/reproducibility diff --git a/docs/source/en/training/distributed_inference.mdx b/docs/source/en/training/distributed_inference.mdx new file mode 100644 index 000000000000..e85b3f11e238 --- /dev/null +++ b/docs/source/en/training/distributed_inference.mdx @@ -0,0 +1,91 @@ +# Distributed inference with multiple GPUs + +On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel. + +This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference. + +## 🤗 Accelerate + +🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code. + +To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process. + +Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes. + +```py +from accelerate import PartialState +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +distributed_state = PartialState() +pipeline.to(distributed_state.device) + +with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt: + result = pipeline(prompt).images[0] + result.save(f"result_{distributed_state.process_index}.png") +``` + +Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script: + +```bash +accelerate launch run_distributed.py --num_processes=2 +``` + + + +To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide. + + + +## PyTorch Distributed + +PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism. + +To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]: + +```py +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from diffusers import DiffusionPipeline + +sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +``` + +You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2. + +Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt: + +```py +def run_inference(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + sd.to(rank) + + if torch.distributed.get_rank() == 0: + prompt = "a dog" + elif torch.distributed.get_rank() == 1: + prompt = "a cat" + + image = sd(prompt).images[0] + image.save(f"./{'_'.join(prompt)}.png") +``` + +To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`: + +```py +def main(): + world_size = 2 + mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True) + + +if __name__ == "__main__": + main() +``` + +Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script: + +```bash +torchrun run_distributed.py --nproc_per_node=2 +``` \ No newline at end of file From 85eff637aad1106f593d7535ec41cdb736b0b2ea Mon Sep 17 00:00:00 2001 From: Will Berman Date: Fri, 19 May 2023 10:45:56 -0700 Subject: [PATCH 26/33] [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479) explicit view kernel size as number elements in flattened indices --- src/diffusers/models/unet_1d_blocks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index a0f0e58f9103..934a4a4a7dcb 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -300,7 +300,8 @@ def forward(self, hidden_states): hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv1d(hidden_states, weight, stride=2) @@ -316,7 +317,8 @@ def forward(self, hidden_states, temb=None): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) From f7b4f51cc2a423c96cb2a4c2282e55feba0be506 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 20 May 2023 13:43:07 +0200 Subject: [PATCH 27/33] mps & onnx tests rework (#3449) * Remove ONNX tests from PR. They are already a part of push_tests.yml. * Remove mps tests from PRs. They are already performed on push. * Fix workflow name for fast push tests. * Extract mps tests to a workflow. For better control/filtering. * Remove --extra-index-url from mps tests * Increase tolerance of mps test This test passes in my Mac (Ventura 13.3) but fails in the CI hardware (Ventura 13.2). I ran the local tests following the same steps that exist in the CI workflow. * Temporarily run mps tests on pr So we can test. * Revert "Temporarily run mps tests on pr" Tests passed, go back to running on push. --- .github/workflows/pr_tests.yml | 66 ------------------------- .github/workflows/push_tests_fast.yml | 55 +-------------------- .github/workflows/push_tests_mps.yml | 68 ++++++++++++++++++++++++++ tests/schedulers/test_scheduler_lms.py | 2 +- 4 files changed, 70 insertions(+), 121 deletions(-) create mode 100644 .github/workflows/push_tests_mps.yml diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 23a7659166c0..162b1ba83d66 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -36,11 +36,6 @@ jobs: runner: docker-cpu image: diffusers/diffusers-flax-cpu report: flax_cpu - - name: Fast ONNXRuntime CPU tests - framework: onnxruntime - runner: docker-cpu - image: diffusers/diffusers-onnxruntime-cpu - report: onnx_cpu - name: PyTorch Example CPU tests framework: pytorch_examples runner: docker-cpu @@ -98,14 +93,6 @@ jobs: --make-reports=tests_${{ matrix.config.report }} \ tests - - name: Run fast ONNXRuntime CPU tests - if: ${{ matrix.config.framework == 'onnxruntime' }} - run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests/ - - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | @@ -123,56 +110,3 @@ jobs: with: name: pr_${{ matrix.config.report }}_test_reports path: reports - - run_fast_tests_apple_m1: - name: Fast PyTorch MPS tests on MacOS - runs-on: [ self-hosted, apple-m1 ] - - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Clean checkout - shell: arch -arch arm64 bash {0} - run: | - git clean -fxd - - - name: Setup miniconda - uses: ./.github/actions/setup-miniconda - with: - python-version: 3.9 - - - name: Install dependencies - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python -m pip install --upgrade pip - ${CONDA_RUN} python -m pip install -e .[quality,test] - ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - ${CONDA_RUN} python -m pip install accelerate --upgrade - ${CONDA_RUN} python -m pip install transformers --upgrade - - - name: Environment - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python utils/print_env.py - - - name: Run fast PyTorch tests on M1 (MPS) - shell: arch -arch arm64 bash {0} - env: - HF_HOME: /System/Volumes/Data/mnt/cache - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - run: | - ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: cat reports/tests_torch_mps_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v2 - with: - name: pr_torch_mps_test_reports - path: reports diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 50ef729161d3..adf4fc8a87bc 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -1,4 +1,4 @@ -name: Slow tests on main +name: Fast tests on main on: push: @@ -108,56 +108,3 @@ jobs: with: name: pr_${{ matrix.config.report }}_test_reports path: reports - - run_fast_tests_apple_m1: - name: Fast PyTorch MPS tests on MacOS - runs-on: [ self-hosted, apple-m1 ] - - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Clean checkout - shell: arch -arch arm64 bash {0} - run: | - git clean -fxd - - - name: Setup miniconda - uses: ./.github/actions/setup-miniconda - with: - python-version: 3.9 - - - name: Install dependencies - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python -m pip install --upgrade pip - ${CONDA_RUN} python -m pip install -e .[quality,test] - ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - ${CONDA_RUN} python -m pip install accelerate --upgrade - ${CONDA_RUN} python -m pip install transformers --upgrade - - - name: Environment - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python utils/print_env.py - - - name: Run fast PyTorch tests on M1 (MPS) - shell: arch -arch arm64 bash {0} - env: - HF_HOME: /System/Volumes/Data/mnt/cache - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - run: | - ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: cat reports/tests_torch_mps_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v2 - with: - name: pr_torch_mps_test_reports - path: reports diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml new file mode 100644 index 000000000000..6b95815f1ea5 --- /dev/null +++ b/.github/workflows/push_tests_mps.yml @@ -0,0 +1,68 @@ +name: Fast mps tests on main + +on: + push: + branches: + - main + +env: + DIFFUSERS_IS_CI: yes + HF_HOME: /mnt/cache + OMP_NUM_THREADS: 8 + MKL_NUM_THREADS: 8 + PYTEST_TIMEOUT: 600 + RUN_SLOW: no + +jobs: + run_fast_tests_apple_m1: + name: Fast PyTorch MPS tests on MacOS + runs-on: [ self-hosted, apple-m1 ] + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Clean checkout + shell: arch -arch arm64 bash {0} + run: | + git clean -fxd + + - name: Setup miniconda + uses: ./.github/actions/setup-miniconda + with: + python-version: 3.9 + + - name: Install dependencies + shell: arch -arch arm64 bash {0} + run: | + ${CONDA_RUN} python -m pip install --upgrade pip + ${CONDA_RUN} python -m pip install -e .[quality,test] + ${CONDA_RUN} python -m pip install torch torchvision torchaudio + ${CONDA_RUN} python -m pip install accelerate --upgrade + ${CONDA_RUN} python -m pip install transformers --upgrade + + - name: Environment + shell: arch -arch arm64 bash {0} + run: | + ${CONDA_RUN} python utils/print_env.py + + - name: Run fast PyTorch tests on M1 (MPS) + shell: arch -arch arm64 bash {0} + env: + HF_HOME: /System/Volumes/Data/mnt/cache + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_torch_mps_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: pr_torch_mps_test_reports + path: reports diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py index 3f31f9696de2..2682886a788d 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -136,5 +136,5 @@ def test_full_loop_device_karras_sigmas(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 3812.9927) < 1e-2 + assert abs(result_sum.item() - 3812.9927) < 2e-2 assert abs(result_mean.item() - 4.9648) < 1e-3 From 4bbc51d94d08a0c74cb28a036e120a32b5237b9a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 21 May 2023 15:26:47 +0530 Subject: [PATCH 28/33] [Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457) * add: debugging to enabling memory efficient processing * add: better warning message. --- src/diffusers/models/attention_processor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a489814c4787..86997632cac1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -191,7 +191,10 @@ def set_use_memory_efficient_attention_xformers( elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk: warnings.warn( "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. " - "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0." + "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) " + "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall " + "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 " + "native efficient flash attention." ) else: try: @@ -213,6 +216,9 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) + print( + f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}" + ) elif is_custom_diffusion: processor = CustomDiffusionXFormersAttnProcessor( train_kv=self.processor.train_kv, @@ -250,6 +256,7 @@ def set_use_memory_efficient_attention_xformers( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + print("Still defaulting to: AttnProcessor2_0 :O") processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk From 49ad61c2045a3278ea0b6648546c0824e9d89c0f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 21 May 2023 15:26:56 +0530 Subject: [PATCH 29/33] [Docs] add note on local directory path. (#3397) add note on local directory path. Co-authored-by: Patrick von Platen --- docs/source/en/training/lora.mdx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index 04eff7af11f8..748d99d5020d 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -146,6 +146,7 @@ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch. + ## DreamBooth [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type). @@ -268,4 +269,7 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is pipe.load_lora_weights(lora_model_path) ``` -* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). \ No newline at end of file +* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). + +**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs, +refer to the respective docstrings. \ No newline at end of file From 7c1e81f83deac2746454920f469a9fc2250bbe2f Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 12:10:30 +0530 Subject: [PATCH 30/33] Add Unet blocks for consistency models --- src/diffusers/models/unet_2d_blocks.py | 220 +++++++++++++++++++++++++ 1 file changed, 220 insertions(+) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 75d9eb3e03df..cad0676f4eab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -90,6 +90,20 @@ def get_down_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif down_block_type == "AttnDownsampleBlock2D": + return AttnDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") @@ -314,6 +328,20 @@ def get_up_block( attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif up_block_type == "AttnUpsampleBlock2D": + return AttnUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, @@ -762,6 +790,101 @@ def forward(self, hidden_states, temb=None, upsample_size=None): return hidden_states, output_states +class AttnDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, upsample_size=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + + class CrossAttnDownBlock2D(nn.Module): def __init__( self, @@ -1831,6 +1954,103 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si return hidden_states +class AttnUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + class CrossAttnUpBlock2D(nn.Module): def __init__( self, From 3a151bd1d96736ce5f57fc288075b4df7a08dbf9 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 17:16:36 +0530 Subject: [PATCH 31/33] Add conversion script for Unet --- scripts/convert_consistency_to_diffusers.py | 163 ++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 scripts/convert_consistency_to_diffusers.py diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py new file mode 100644 index 000000000000..bd269c735dde --- /dev/null +++ b/scripts/convert_consistency_to_diffusers.py @@ -0,0 +1,163 @@ +import argparse +import io +from diffusers.models.unet_2d import UNet2DModel +import requests +import torch + +UNET_CONFIG = { + "sample_size": 64, + "in_channels": 3, + "out_channels": 3, + "layers_per_block" : 3, + "num_class_embeds": 1000, + "block_out_channels": [192, 192*2, 192*3, 192*4], + "attention_head_dim" : 64, + "down_block_types": ["ResnetDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D", "AttnDownsampleBlock2D"], + "up_block_types": ["AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "AttnUpsampleBlock2D", "ResnetUpsampleBlock2D"], + "resnet_time_scale_shift" : "scale_shift" +} + + +def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False): + new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"] + new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"] + new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"] + new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"] + new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"] + new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"] + new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"] + new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"] + + if has_skip: + new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"] + new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"] + + + return new_checkpoint + +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix): + weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) + bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) + + new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] + new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] + + new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + + new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) + + return new_checkpoint + + +def con_pt_to_diffuser(checkpoint_path: str, output_path: str): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] + + new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] + + down_block_types = UNET_CONFIG["down_block_types"] + layers_per_block = UNET_CONFIG["layers_per_block"] + current_layer = 1 + + for (i,layer_type) in enumerate(down_block_types): + + if layer_type == "ResnetDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + elif layer_type == "AttnDownsampleBlock2D": + for j in range(layers_per_block): + new_prefix = f"down_blocks.{i}.resnets.{j}" + old_prefix = f"input_blocks.{current_layer}.0" + has_skip = True if j == 0 else False + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) + new_prefix = f"down_blocks.{i}.attentions.{j}" + old_prefix = f"input_blocks.{current_layer}.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + if i!= len(down_block_types)-1: + new_prefix = f"down_blocks.{i}.downsamplers.0" + old_prefix = f"input_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + # hardcoded the mid-block for now + new_prefix = f"mid_block.resnets.0" + old_prefix = f"middle_block.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = f"mid_block.attentions.0" + old_prefix = f"middle_block.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_prefix = f"mid_block.resnets.1" + old_prefix = f"middle_block.2" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + current_layer = 0 + up_block_types = UNET_CONFIG["up_block_types"] + + for (i, layer_type) in enumerate (up_block_types): + if layer_type == "ResnetUpsampleBlock2D": + for j in range(layers_per_block+1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + current_layer += 1 + elif layer_type == "AttnUpsampleBlock2D": + for j in range(layers_per_block+1): + new_prefix = f"up_blocks.{i}.resnets.{j}" + old_prefix = f"output_blocks.{current_layer}.0" + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) + new_prefix = f"up_blocks.{i}.attentions.{j}" + old_prefix = f"output_blocks.{current_layer}.1" + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + current_layer += 1 + + new_prefix = f"up_blocks.{i}.upsamplers.0" + old_prefix = f"output_blocks.{current_layer-1}.2" + # print(new_prefix) + # print(old_prefix) + new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) + + + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] + + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + + args = parser.parse_args() + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, args.dump_path) + image_unet = UNet2DModel(**UNET_CONFIG) + # print(image_unet) + # exit() + image_unet.load_state_dict(converted_unet_ckpt) + image_unet.save_pretrained(args.dump_path) \ No newline at end of file From b6c5e15bab797e24587fd6fc9d4289d79f7ba43d Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Mon, 22 May 2023 17:54:12 +0530 Subject: [PATCH 32/33] Fix bug in new unet blocks --- src/diffusers/models/unet_2d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cad0676f4eab..f94f178f11e7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -877,7 +877,7 @@ def forward(self, hidden_states, temb=None, upsample_size=None): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, temb) output_states += (hidden_states,) @@ -2046,7 +2046,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, temb) return hidden_states From 4e93f09f2b88c25edfbb4ab6cc57b7a9befd9f51 Mon Sep 17 00:00:00 2001 From: ayushmangal Date: Tue, 23 May 2023 11:31:29 +0530 Subject: [PATCH 33/33] Fix attention weight loading --- scripts/convert_consistency_to_diffusers.py | 37 ++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py index bd269c735dde..37a89bb86693 100644 --- a/scripts/convert_consistency_to_diffusers.py +++ b/scripts/convert_consistency_to_diffusers.py @@ -37,19 +37,31 @@ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip= return new_checkpoint -def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix): - weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0) - bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0) +def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim=64): + c, _, _, _ = checkpoint[f"{old_prefix}.qkv.weight"].shape + n_heads = c // (attention_head_dim*3) + old_weights = checkpoint[f"{old_prefix}.qkv.weight"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) + old_biases = checkpoint[f"{old_prefix}.qkv.bias"].reshape(n_heads, attention_head_dim*3, -1, 1, 1) + + weight_q, weight_k, weight_v = old_weights.chunk(3, dim=1) + weight_q = weight_q.reshape(n_heads*attention_head_dim, -1, 1, 1) + weight_k = weight_k.reshape(n_heads*attention_head_dim, -1, 1, 1) + weight_v = weight_v.reshape(n_heads*attention_head_dim, -1, 1, 1) + + bias_q, bias_k, bias_v = old_biases.chunk(3, dim=1) + bias_q = bias_q.reshape(n_heads*attention_head_dim, -1, 1, 1) + bias_k = bias_k.reshape(n_heads*attention_head_dim, -1, 1, 1) + bias_v = bias_v.reshape(n_heads*attention_head_dim, -1, 1, 1) new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"] new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"] - new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1) - new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1) + new_checkpoint[f"{new_prefix}.to_q.weight"] = torch.squeeze(weight_q) + new_checkpoint[f"{new_prefix}.to_q.bias"] = torch.squeeze(bias_q) + new_checkpoint[f"{new_prefix}.to_k.weight"] = torch.squeeze(weight_k) + new_checkpoint[f"{new_prefix}.to_k.bias"] = torch.squeeze(bias_k) + new_checkpoint[f"{new_prefix}.to_v.weight"] = torch.squeeze(weight_v) + new_checkpoint[f"{new_prefix}.to_v.bias"] = torch.squeeze(bias_v) new_checkpoint[f"{new_prefix}.to_out.0.weight"] = checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1) new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1) @@ -73,6 +85,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): down_block_types = UNET_CONFIG["down_block_types"] layers_per_block = UNET_CONFIG["layers_per_block"] + attention_head_dim = UNET_CONFIG["attention_head_dim"] current_layer = 1 for (i,layer_type) in enumerate(down_block_types): @@ -92,7 +105,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip) new_prefix = f"down_blocks.{i}.attentions.{j}" old_prefix = f"input_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) current_layer += 1 if i!= len(down_block_types)-1: @@ -107,7 +120,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_prefix = f"mid_block.attentions.0" old_prefix = f"middle_block.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) new_prefix = f"mid_block.resnets.1" old_prefix = f"middle_block.2" new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) @@ -129,7 +142,7 @@ def con_pt_to_diffuser(checkpoint_path: str, output_path: str): new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True) new_prefix = f"up_blocks.{i}.attentions.{j}" old_prefix = f"output_blocks.{current_layer}.1" - new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix) + new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim) current_layer += 1 new_prefix = f"up_blocks.{i}.upsamplers.0"