From aab478100fce9fea6103f192a21065df11ffa0ed Mon Sep 17 00:00:00 2001 From: root Date: Fri, 14 Apr 2023 14:02:59 +0800 Subject: [PATCH 1/3] try multi controlnet inpaint --- .../stable_diffusion_controlnet_inpaint.py | 140 +++++++++++++++++- 1 file changed, 136 insertions(+), 4 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index c47f4c3194e8..e72dbcd31889 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -1,7 +1,7 @@ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -11,6 +11,7 @@ from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, @@ -184,7 +185,7 @@ def prepare_mask_image(mask_image): def prepare_controlnet_conditioning_image( - controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype + controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance, ): if not isinstance(controlnet_conditioning_image, torch.Tensor): if isinstance(controlnet_conditioning_image, PIL.Image.Image): @@ -214,6 +215,9 @@ def prepare_controlnet_conditioning_image( controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) + return controlnet_conditioning_image @@ -230,7 +234,8 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + # controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -254,6 +259,9 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -264,6 +272,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -522,6 +531,42 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + else: + raise ValueError("controlnet condition image is not valid") + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + else: + raise ValueError("prompt or prompt_embeds are not valid") + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, prompt, @@ -534,6 +579,7 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + controlnet_conditioning_scale=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -572,6 +618,7 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + """ controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( @@ -580,7 +627,9 @@ def check_inputs( controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( controlnet_conditioning_image[0], torch.Tensor ) + """ + """ if ( not controlnet_cond_image_is_pil and not controlnet_cond_image_is_tensor @@ -590,7 +639,9 @@ def check_inputs( raise TypeError( "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" ) + """ + """ if controlnet_cond_image_is_pil: controlnet_cond_image_batch_size = 1 elif controlnet_cond_image_is_tensor: @@ -599,18 +650,53 @@ def check_inputs( controlnet_cond_image_batch_size = len(controlnet_conditioning_image) elif controlnet_cond_image_is_tensor_list: controlnet_cond_image_batch_size = len(controlnet_conditioning_image) + """ + # check controlnet condition image + if isinstance(self.controlnet, ControlNetModel): + self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, MultiControlNetModel): + if not isinstance(controlnet_conditioning_image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + if len(controlnet_conditioning_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + for image_ in controlnet_conditioning_image: + self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds) + else: + assert False + + """ if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] + """ + """ if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" ) + """ + + # Check `controlnet_conditioning_scale` + if isinstance(self.controlnet, ControlNetModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(self.controlnet, MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor): raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor") @@ -630,6 +716,8 @@ def check_inputs( image_channels, image_height, image_width = image.shape elif image.ndim == 4: image_batch_size, image_channels, image_height, image_width = image.shape + else: + assert False if mask_image.ndim == 2: mask_image_batch_size = 1 @@ -797,7 +885,8 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, + # controlnet_conditioning_scale: float = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -897,6 +986,7 @@ def __call__( negative_prompt, prompt_embeds, negative_prompt_embeds, + controlnet_conditioning_scale, ) # 2. Define call parameters @@ -913,6 +1003,9 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -929,6 +1022,39 @@ def __call__( mask_image = prepare_mask_image(mask_image) + # condition image(s) + if isinstance(self.controlnet, ControlNetModel): + controlnet_conditioning_image = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=controlnet_conditioning_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + elif isinstance(self.controlnet, MultiControlNetModel): + controlnet_conditioning_images = [] + + for image_ in controlnet_conditioning_image: + image_ = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + controlnet_conditioning_images.append(image_) + + controlnet_conditioning_image = controlnet_conditioning_images + else: + assert False + + """ controlnet_conditioning_image = prepare_controlnet_conditioning_image( controlnet_conditioning_image, width, @@ -938,6 +1064,7 @@ def __call__( device, self.controlnet.dtype, ) + """ masked_image = image * (mask_image < 0.5) @@ -979,8 +1106,10 @@ def __call__( do_classifier_free_guidance, ) + """ if do_classifier_free_guidance: controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) + """ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -1007,14 +1136,17 @@ def __call__( t, encoder_hidden_states=prompt_embeds, controlnet_cond=controlnet_conditioning_image, + conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) + """ down_block_res_samples = [ down_block_res_sample * controlnet_conditioning_scale for down_block_res_sample in down_block_res_samples ] mid_block_res_sample *= controlnet_conditioning_scale + """ # predict the noise residual noise_pred = self.unet( From 9de6eb6222ceb010517ce127ee89b74606a0b630 Mon Sep 17 00:00:00 2001 From: timegate Date: Mon, 17 Apr 2023 12:10:12 +0800 Subject: [PATCH 2/3] multi controlnet inpaint --- .../community/stable_diffusion_controlnet_inpaint.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index e72dbcd31889..22621e55db84 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -185,7 +185,14 @@ def prepare_mask_image(mask_image): def prepare_controlnet_conditioning_image( - controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance, + controlnet_conditioning_image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance, ): if not isinstance(controlnet_conditioning_image, torch.Tensor): if isinstance(controlnet_conditioning_image, PIL.Image.Image): From 50ee696758d69d45b179f9461dedd6b66493c77b Mon Sep 17 00:00:00 2001 From: timegate Date: Mon, 17 Apr 2023 12:15:52 +0800 Subject: [PATCH 3/3] multi controlnet inpaint --- .../stable_diffusion_controlnet_inpaint.py | 77 ------------------- 1 file changed, 77 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index 22621e55db84..aae199f91b9e 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -241,7 +241,6 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - # controlnet: ControlNetModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, @@ -625,40 +624,6 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - """ - controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) - controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) - controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], PIL.Image.Image - ) - controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], torch.Tensor - ) - """ - - """ - if ( - not controlnet_cond_image_is_pil - and not controlnet_cond_image_is_tensor - and not controlnet_cond_image_is_pil_list - and not controlnet_cond_image_is_tensor_list - ): - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) - """ - - """ - if controlnet_cond_image_is_pil: - controlnet_cond_image_batch_size = 1 - elif controlnet_cond_image_is_tensor: - controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] - elif controlnet_cond_image_is_pil_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - elif controlnet_cond_image_is_tensor_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - """ - # check controlnet condition image if isinstance(self.controlnet, ControlNetModel): self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds) @@ -674,22 +639,6 @@ def check_inputs( else: assert False - """ - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - """ - - """ - if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - """ - # Check `controlnet_conditioning_scale` if isinstance(self.controlnet, ControlNetModel): if not isinstance(controlnet_conditioning_scale, float): @@ -892,7 +841,6 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - # controlnet_conditioning_scale: float = 1.0, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, ): r""" @@ -1061,18 +1009,6 @@ def __call__( else: assert False - """ - controlnet_conditioning_image = prepare_controlnet_conditioning_image( - controlnet_conditioning_image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) - """ - masked_image = image * (mask_image < 0.5) # 5. Prepare timesteps @@ -1113,11 +1049,6 @@ def __call__( do_classifier_free_guidance, ) - """ - if do_classifier_free_guidance: - controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) - """ - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -1147,14 +1078,6 @@ def __call__( return_dict=False, ) - """ - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - """ - # predict the noise residual noise_pred = self.unet( inpainting_latent_model_input,