diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9a8d5edc7e07..e0f039c7d51f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1104,15 +1104,22 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -1122,7 +1129,6 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -1130,7 +1136,7 @@ def __call__( controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, - added_cond_kwargs=added_cond_kwargs, + added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 2496b9c48ab8..8fb76499dc14 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -300,6 +300,28 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_controlnet_sdxl_guess(self): + device = "cpu" + + components = self.get_dummy_components() + + sd_pipe = self.pipeline_class(**components) + sd_pipe = sd_pipe.to(device) + + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guess_mode"] = True + + output = sd_pipe(**inputs) + image_slice = output.images[0, -3:, -3:, -1] + expected_slice = np.array( + [0.7330834, 0.590667, 0.5667336, 0.6029023, 0.5679491, 0.5968194, 0.4032986, 0.47612396, 0.5089609] + ) + + # make sure that it's equal + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4 + class StableDiffusionXLMultiControlNetPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase