diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 95e5fe7db061..b60346c92580 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -1,7 +1,7 @@ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -10,6 +10,7 @@ from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, @@ -86,7 +87,14 @@ def prepare_image(image): def prepare_controlnet_conditioning_image( - controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype + controlnet_conditioning_image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance, ): if not isinstance(controlnet_conditioning_image, torch.Tensor): if isinstance(controlnet_conditioning_image, PIL.Image.Image): @@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image( controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) + return controlnet_conditioning_image @@ -132,7 +143,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -156,6 +167,9 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -425,6 +439,42 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + else: + raise ValueError("controlnet condition image is not valid") + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + else: + raise ValueError("prompt or prompt_embeds are not valid") + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, prompt, @@ -439,6 +489,7 @@ def check_inputs( strength=None, controlnet_guidance_start=None, controlnet_guidance_end=None, + controlnet_conditioning_scale=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -477,58 +528,51 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) - controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) - controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], PIL.Image.Image - ) - controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], torch.Tensor - ) + # check controlnet condition image - if ( - not controlnet_cond_image_is_pil - and not controlnet_cond_image_is_tensor - and not controlnet_cond_image_is_pil_list - and not controlnet_cond_image_is_tensor_list - ): - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) + if isinstance(self.controlnet, ControlNetModel): + self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, MultiControlNetModel): + if not isinstance(controlnet_conditioning_image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") - if controlnet_cond_image_is_pil: - controlnet_cond_image_batch_size = 1 - elif controlnet_cond_image_is_tensor: - controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] - elif controlnet_cond_image_is_pil_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - elif controlnet_cond_image_is_tensor_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) + if len(controlnet_conditioning_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] + for image_ in controlnet_conditioning_image: + self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds) + else: + assert False - if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" - ) + # Check `controlnet_conditioning_scale` + + if isinstance(self.controlnet, ControlNetModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(self.controlnet, MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False if isinstance(image, torch.Tensor): if image.ndim != 3 and image.ndim != 4: raise ValueError("`image` must have 3 or 4 dimensions") - # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4: - # raise ValueError("`mask_image` must have 2, 3, or 4 dimensions") - if image.ndim == 3: image_batch_size = 1 image_channels, image_height, image_width = image.shape elif image.ndim == 4: image_batch_size, image_channels, image_height, image_width = image.shape + else: + assert False if image_channels != 3: raise ValueError("`image` must have 3 channels") @@ -660,7 +704,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, controlnet_guidance_start: float = 0.0, controlnet_guidance_end: float = 1.0, ): @@ -761,7 +805,6 @@ def __call__( self.check_inputs( prompt, image, - # mask_image, controlnet_conditioning_image, height, width, @@ -772,6 +815,7 @@ def __call__( strength, controlnet_guidance_start, controlnet_guidance_end, + controlnet_conditioning_scale, ) # 2. Define call parameters @@ -788,6 +832,9 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -799,22 +846,41 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Prepare mask, image, and controlnet_conditioning_image + # 4. Prepare image, and controlnet_conditioning_image image = prepare_image(image) - # mask_image = prepare_mask_image(mask_image) + # condition image(s) + if isinstance(self.controlnet, ControlNetModel): + controlnet_conditioning_image = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=controlnet_conditioning_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + elif isinstance(self.controlnet, MultiControlNetModel): + controlnet_conditioning_images = [] + + for image_ in controlnet_conditioning_image: + image_ = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) - controlnet_conditioning_image = prepare_controlnet_conditioning_image( - controlnet_conditioning_image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) + controlnet_conditioning_images.append(image_) - # masked_image = image * (mask_image < 0.5) + controlnet_conditioning_image = controlnet_conditioning_images + else: + assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -832,9 +898,6 @@ def __call__( generator, ) - if do_classifier_free_guidance: - controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -864,15 +927,10 @@ def __call__( t, encoder_hidden_states=prompt_embeds, controlnet_cond=controlnet_conditioning_image, + conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - # predict the noise residual noise_pred = self.unet( latent_model_input,