diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index c4a8777c5ed0..4d51f97f9e31 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1057,32 +1057,41 @@ def __call__(
             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
         )
 
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        use_controlnet_conditional_branch_only = guess_mode and do_classifier_free_guidance
+        if use_controlnet_conditional_branch_only:
+            controlnet_prompt_embeds = prompt_embeds
+            controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
         if do_classifier_free_guidance:
+            # expand all inputs if we are doing classifier free guidance
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
 
-        prompt_embeds = prompt_embeds.to(device)
-        add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+        if not use_controlnet_conditional_branch_only:
+            # controlnet inputs are the same as inputs of base unet model
+            controlnet_prompt_embeds = prompt_embeds
+            controlnet_added_cond_kwargs = added_cond_kwargs
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                scaled_latents = self.scheduler.scale_model_input(latents, t)
+                latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
-                    # Infer ControlNet only for the conditional batch.
-                    control_model_input = latents
-                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
-                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                if use_controlnet_conditional_branch_only:
+                    control_model_input = scaled_latents
                 else:
                     control_model_input = latent_model_input
-                    controlnet_prompt_embeds = prompt_embeds
 
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1092,7 +1101,6 @@ def __call__(
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
-                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     control_model_input,
                     t,
@@ -1100,7 +1108,7 @@ def __call__(
                     controlnet_cond=image,
                     conditioning_scale=cond_scale,
                     guess_mode=guess_mode,
-                    added_cond_kwargs=added_cond_kwargs,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
                     return_dict=False,
                 )
 
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 0c7ee4972f01..0e2769392fb0 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 import unittest
+from parameterized import parameterized_class
+import itertools
 
 import numpy as np
 import torch
@@ -47,6 +49,10 @@
 enable_full_determinism()
 
 
+@parameterized_class(
+    ("guess_mode", "guidance_scale"),
+    list(itertools.product([False, True], [1.0, 6.0]))
+)
 class StableDiffusionXLControlNetPipelineFastTests(
     PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
@@ -56,6 +62,12 @@ class StableDiffusionXLControlNetPipelineFastTests(
     image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
 
+    @classmethod
+    def setUpClass(cls):
+        if cls == StableDiffusionXLControlNetPipelineFastTests:
+            raise unittest.SkipTest("`parameterized_class` bug, see https://github.com/wolever/parameterized/issues/119")
+        super().setUpClass()
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -158,7 +170,8 @@ def get_dummy_inputs(self, device, seed=0):
             "prompt": "A painting of a squirrel eating a burger",
             "generator": generator,
             "num_inference_steps": 2,
-            "guidance_scale": 6.0,
+            "guidance_scale": self.guidance_scale,
+            "guess_mode": self.guess_mode,
             "output_type": "numpy",
             "image": image,
         }
@@ -234,6 +247,9 @@ def test_stable_diffusion_xl_multi_prompts(self):
         # ensure the results are not equal
         assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
 
+        if self.guidance_scale <= 1:  # negative prompt has no effect without CFG
+            return
+
         # manually set a negative_prompt
         inputs = self.get_dummy_inputs(torch_device)
         inputs["negative_prompt"] = "negative prompt"
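For reference, a minimal usage sketch of the code path this patch reworks: guess mode combined with classifier-free guidance. The checkpoint ids and the Canny-image URL below are illustrative assumptions, not part of the patch.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints for illustration; any SDXL base + SDXL ControlNet pair should do.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Assumed conditioning input: a precomputed Canny edge map.
canny = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

# guess_mode=True together with guidance_scale > 1 hits the branch the patch
# changes: the ControlNet now receives the precomputed conditional-only prompt
# embeds and added cond kwargs instead of chunking them inside the denoising loop.
image = pipe(
    "a bird, best quality", image=canny, guess_mode=True, guidance_scale=6.0
).images[0]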