Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 123 additions & 61 deletions examples/community/stable_diffusion_controlnet_inpaint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
Expand All @@ -11,6 +11,7 @@

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
Expand Down Expand Up @@ -184,7 +185,14 @@ def prepare_mask_image(mask_image):


def prepare_controlnet_conditioning_image(
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
controlnet_conditioning_image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
):
if not isinstance(controlnet_conditioning_image, torch.Tensor):
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
Expand Down Expand Up @@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image(

controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

return controlnet_conditioning_image


Expand All @@ -230,7 +241,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
Expand All @@ -254,6 +265,9 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -264,6 +278,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

Expand Down Expand Up @@ -522,6 +537,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
    """Validate one ControlNet conditioning image against the prompt batch.

    Args:
        image: A PIL image, a torch tensor, or a non-empty list whose first
            element is a PIL image or a torch tensor.
        prompt: A string, a list, or `None`.
        prompt_embeds: Consulted only when `prompt` is neither a string nor a
            list; its `shape[0]` is then used as the prompt batch size.

    Raises:
        TypeError: If `image` is not one of the accepted types, including an
            empty list.
        ValueError: If no prompt batch size can be determined, or if the
            image batch size is neither 1 nor equal to the prompt batch size.
    """
    image_type_error = (
        "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
    )

    # Determine the image batch size while validating the input type.
    if isinstance(image, torch.Tensor):
        image_batch_size = image.shape[0]
    elif isinstance(image, list):
        # An empty list has no first element to inspect. Reject it with the
        # same TypeError as any other unsupported input instead of letting
        # `image[0]` fail with an IndexError.
        if not image or not isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
            raise TypeError(image_type_error)
        image_batch_size = len(image)
    elif isinstance(image, PIL.Image.Image):
        image_batch_size = 1
    else:
        raise TypeError(image_type_error)

    # `prompt` takes precedence over `prompt_embeds` when both are given.
    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        raise ValueError("prompt or prompt_embeds are not valid")

    # A single image is accepted for any prompt batch size; otherwise the two
    # batch sizes must match.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )

def check_inputs(
self,
prompt,
Expand All @@ -534,6 +585,7 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down Expand Up @@ -572,45 +624,35 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], PIL.Image.Image
)
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], torch.Tensor
)

if (
not controlnet_cond_image_is_pil
and not controlnet_cond_image_is_tensor
and not controlnet_cond_image_is_pil_list
and not controlnet_cond_image_is_tensor_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)

if controlnet_cond_image_is_pil:
controlnet_cond_image_batch_size = 1
elif controlnet_cond_image_is_tensor:
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
elif controlnet_cond_image_is_pil_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
elif controlnet_cond_image_is_tensor_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)

if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]

if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# check controlnet condition image
if isinstance(self.controlnet, ControlNetModel):
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(controlnet_conditioning_image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)
for image_ in controlnet_conditioning_image:
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
else:
assert False

# Check `controlnet_conditioning_scale`
if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False

if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
Expand All @@ -630,6 +672,8 @@ def check_inputs(
image_channels, image_height, image_width = image.shape
elif image.ndim == 4:
image_batch_size, image_channels, image_height, image_width = image.shape
else:
assert False

if mask_image.ndim == 2:
mask_image_batch_size = 1
Expand Down Expand Up @@ -797,7 +841,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -897,6 +941,7 @@ def __call__(
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
)

# 2. Define call parameters
Expand All @@ -913,6 +958,9 @@ def __call__(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
Expand All @@ -929,15 +977,37 @@ def __call__(

mask_image = prepare_mask_image(mask_image)

controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
# condition image(s)
if isinstance(self.controlnet, ControlNetModel):
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=controlnet_conditioning_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_conditioning_images = []

for image_ in controlnet_conditioning_image:
image_ = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
controlnet_conditioning_images.append(image_)

controlnet_conditioning_image = controlnet_conditioning_images
else:
assert False

masked_image = image * (mask_image < 0.5)

Expand Down Expand Up @@ -979,9 +1049,6 @@ def __call__(
do_classifier_free_guidance,
)

if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

Expand All @@ -1007,15 +1074,10 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=controlnet_conditioning_image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False,
)

down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale

# predict the noise residual
noise_pred = self.unet(
inpainting_latent_model_input,
Expand Down