From abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9 Mon Sep 17 00:00:00 2001
From: Takuma Mori
Date: Thu, 4 May 2023 14:45:48 +0900
Subject: [PATCH 1/5] add inferring_controlnet_cond_batch

---
 .../pipeline_stable_diffusion_controlnet.py  | 38 +++++++++++--------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index 3bd7f82d7eb6..46c229e825ba 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -658,8 +658,7 @@ def prepare_image(
         num_images_per_prompt,
         device,
         dtype,
-        do_classifier_free_guidance=False,
-        guess_mode=False,
+        inferring_controlnet_cond_batch=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -696,7 +695,7 @@ def prepare_image(

         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance and not guess_mode:
+        if not inferring_controlnet_cond_batch:
             image = torch.cat([image] * 2)

         return image
@@ -898,7 +897,16 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

-        # 3. Encode input prompt
+        # 3. Determination of whether to infer ControlNet using only for the conditional batch.
+        global_pool_conditions = False
+        if isinstance(self.controlnet, ControlNetModel):
+            global_pool_conditions = self.controlnet.config.global_pool_conditions
+        else:
+            ... # TODO: Implement for MultiControlNetModel
+
+        inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
+
+        # 4. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -909,7 +917,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )

-        # 4. Prepare image
+        # 5. Prepare image
         if isinstance(self.controlnet, ControlNetModel):
             image = self.prepare_image(
                 image=image,
@@ -919,8 +927,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                guess_mode=guess_mode,
+                inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -934,8 +941,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    guess_mode=guess_mode,
+                    inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
                 )

                 images.append(image_)
@@ -944,11 +950,11 @@ def __call__(
         else:
             assert False

-        # 5. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

-        # 6. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -961,10 +967,10 @@ def __call__(
             latents,
         )

-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -973,8 +979,8 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
-                    # Infer ControlNet only for the conditional batch.
+                if inferring_controlnet_cond_batch:
+                    # Inferring ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
@@ -991,7 +997,7 @@ def __call__(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if inferring_controlnet_cond_batch:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
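For context, the change above makes ControlNet run on only the conditional half of the classifier-free-guidance batch and then pads its residuals with zeros for the unconditional half. A minimal standalone sketch of that idea (illustrative shapes and placeholder tensors, not part of the patch):

import torch

latents = torch.randn(1, 4, 64, 64)          # conditional sample only
prompt_embeds = torch.randn(2, 77, 768)      # [uncond, cond] pair built for CFG
cond_embeds = prompt_embeds.chunk(2)[1]      # keep just the conditional half

# Placeholders standing in for ControlNet outputs computed from (latents, cond_embeds):
down_block_res_samples = [torch.randn(1, 320, 64, 64)]
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

# Pad with zeros so the residuals line up with the doubled UNet batch;
# the unconditional half receives no ControlNet signal.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])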
From a0fa8fe515b72c7df767ea3eb761cb009182f678 Mon Sep 17 00:00:00 2001
From: Takuma Mori
Date: Mon, 8 May 2023 23:33:37 +0900
Subject: [PATCH 2/5] Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

---
 .../pipeline_stable_diffusion_controlnet.py  | 36 ++++++++-----------
 1 file changed, 15 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index 66d570da9120..e9e9b9bb250e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -678,7 +678,8 @@ def prepare_image(
         num_images_per_prompt,
         device,
         dtype,
-        inferring_controlnet_cond_batch=False,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -715,7 +716,7 @@ def prepare_image(

         image = image.to(device=device, dtype=dtype)

-        if not inferring_controlnet_cond_batch:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)

         return image
@@ -917,16 +918,7 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

-        # 3. Determination of whether to infer ControlNet using only for the conditional batch.
-        global_pool_conditions = False
-        if isinstance(self.controlnet, ControlNetModel):
-            global_pool_conditions = self.controlnet.config.global_pool_conditions
-        else:
-            ... # TODO: Implement for MultiControlNetModel
-
-        inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
-
-        # 4. Encode input prompt
+        # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -954,7 +946,8 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.controlnet.dtype,
-                inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif (
             isinstance(self.controlnet, MultiControlNetModel)
@@ -972,7 +965,8 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.controlnet.dtype,
-                    inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )

                 images.append(image_)
@@ -981,11 +975,11 @@ def __call__(
         else:
             assert False

-        # 6. Prepare timesteps
+        # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

-        # 7. Prepare latent variables
+        # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -998,10 +992,10 @@ def __call__(
             latents,
         )

-        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 9. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1010,8 +1004,8 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if inferring_controlnet_cond_batch:
-                    # Inferring ControlNet only for the conditional batch.
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
@@ -1028,7 +1022,7 @@ def __call__(
                     return_dict=False,
                 )

-                if inferring_controlnet_cond_batch:
+                if guess_mode and do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
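The revert restores the original prepare_image() parameters, where the conditioning image is duplicated only when the full unconditional + conditional batch will be fed to ControlNet. A small self-contained sketch of that branch, assuming the image has already been converted to a tensor:

import torch

def maybe_duplicate_for_cfg(image: torch.Tensor, do_classifier_free_guidance: bool, guess_mode: bool) -> torch.Tensor:
    # With CFG on and guess mode off, ControlNet sees the doubled batch,
    # so the conditioning image is doubled to match. In guess mode ControlNet
    # only sees the conditional half, so the image is left as-is.
    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)
    return image

image = torch.rand(1, 3, 512, 512)
print(maybe_duplicate_for_cfg(image, do_classifier_free_guidance=True, guess_mode=False).shape)  # torch.Size([2, 3, 512, 512])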
From c9975c8122968b4e99acbb703289d50c6dac52da Mon Sep 17 00:00:00 2001
From: Takuma Mori
Date: Mon, 8 May 2023 23:45:29 +0900
Subject: [PATCH 3/5] set guess_mode to True whenever global_pool_conditions is True

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/controlnet.py            | 2 +-
 .../pipeline_stable_diffusion_controlnet.py   | 7 +++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index 7b36d2eed96a..0b0ce0be547f 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -558,7 +558,7 @@ def forward(
         mid_block_res_sample = self.controlnet_mid_block(sample)

         # 6. scaling
-        if guess_mode:
+        if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
             scales = scales * conditioning_scale
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index e9e9b9bb250e..5f7b1c164135 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -918,6 +918,13 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

+        global_pool_conditions = (
+            self.controlnet.config.global_pool_conditions
+            if isinstance(self.controlnet, ControlNetModel)
+            else self.controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
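The guarded block in controlnet.py is the guess-mode residual scaling; with a global_pool_conditions checkpoint it is now skipped even though guess_mode is forced on in the pipeline. A rough standalone sketch of what that scaling does, assuming a hypothetical 12 down-block residuals plus the mid-block residual:

import torch

down_block_res_samples = [torch.randn(2, 320, 64, 64) for _ in range(12)]  # placeholder residuals
mid_block_res_sample = torch.randn(2, 1280, 8, 8)
conditioning_scale = 1.0
guess_mode, global_pool_conditions = True, True  # e.g. a checkpoint like control_v11e_sd15_shuffle

if guess_mode and not global_pool_conditions:
    # 13 scales from 0.1 to 1.0: shallow blocks are damped the most, the mid block the least.
    scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) * conditioning_scale
    down_block_res_samples = [s * scale for s, scale in zip(down_block_res_samples, scales)]
    mid_block_res_sample = mid_block_res_sample * scales[-1]
else:
    down_block_res_samples = [s * conditioning_scale for s in down_block_res_samples]
    mid_block_res_sample = mid_block_res_sample * conditioning_scale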
From 7704ec9fa0631df408291c40b7d18fdf34c1c31d Mon Sep 17 00:00:00 2001
From: Takuma Mori
Date: Mon, 8 May 2023 23:47:08 +0900
Subject: [PATCH 4/5] nit

---
 .../stable_diffusion/pipeline_stable_diffusion_controlnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index 5f7b1c164135..f8c3856c015e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -936,7 +936,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )

-        # 5. Prepare image
+        # 4. Prepare image
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
         )
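The context line above checks whether the ControlNet has been wrapped by torch.compile before doing the isinstance dispatch. A tiny sketch of that detection on an arbitrary module (assumes PyTorch 2.x; not part of the patch):

import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Linear(4, 4)
compiled = torch.compile(net)

is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
    compiled, torch._dynamo.eval_frame.OptimizedModule
)
original = compiled._orig_mod if is_compiled else compiled  # unwrap to reach the original module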
From 954ef9722fe0957bbefab59a6209658618f46a85 Mon Sep 17 00:00:00 2001
From: Takuma Mori
Date: Tue, 9 May 2023 00:09:04 +0900
Subject: [PATCH 5/5] add integration test

---
 .../test_stable_diffusion_controlnet.py      | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
index 279df4a32b29..c7bb49771e98 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py
@@ -622,6 +622,37 @@ def test_stable_diffusion_compile(self):

         assert np.abs(expected_image - image).max() < 1e-1

+    def test_v11_shuffle_global_pool_conditions(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "New York"
+        image = load_image(
+            "https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"
+        )
+
+        output = pipe(
+            prompt,
+            image,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=3,
+            guidance_scale=7.0,
+        )
+
+        image = output.images[0]
+        assert image.shape == (512, 640, 3)
+
+        image_slice = image[-3:, -3:, -1]
+        expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 @slow
 @require_torch_gpu
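From the user side, the behaviour exercised by this test looks roughly like the following: with a global_pool_conditions checkpoint such as control_v11e_sd15_shuffle, the pipeline now applies the guess-mode handling automatically, without passing guess_mode=True. A usage sketch mirroring the integration test (needs a GPU, downloads the listed checkpoints; step count and output filename are arbitrary):

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
)
pipe.enable_model_cpu_offload()

control = load_image(
    "https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"
)
# No explicit guess_mode=True needed: global_pool_conditions on the checkpoint implies it.
image = pipe("New York", control, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("shuffle_new_york.png")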