From ce4c1e2eb92ccd4d433542de2721488692eaebbd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Aug 2023 05:16:31 +0000 Subject: [PATCH 1/5] fix --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9a8d5edc7e07..604ec5f5c2c8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1104,15 +1104,19 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds.chunk(2)[1], "time_ids": add_time_ids.chunk(2)[1]} else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -1122,7 +1126,7 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -1130,7 +1134,7 @@ def __call__( controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, - added_cond_kwargs=added_cond_kwargs, + added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) From 6caa3e212ffa98af49e48e6c482ca28674653400 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Aug 2023 05:17:35 +0000 Subject: [PATCH 2/5] fix --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 604ec5f5c2c8..e0f039c7d51f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1112,7 +1112,10 @@ def __call__( control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds.chunk(2)[1], "time_ids": add_time_ids.chunk(2)[1]} + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds @@ -1126,7 +1129,6 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From 66579d5bbb9945b2672f6178d3a3d532902a7d1a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Aug 2023 03:00:33 +0000 Subject: [PATCH 3/5] add a test --- .../controlnet/test_controlnet_sdxl.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 2496b9c48ab8..0f7fdff19626 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -300,6 +300,26 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_controlnet_sdxl_guess(self): + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward without prompt embeds + inputs = self.get_dummy_inputs(torch_device) + inputs["guess_mode"] = True + + output = sd_pipe(**inputs) + image_slice = output.images[0, -3:, -3:, -1] + expected_slice = np.array( + [0.45368838, 0.38424692, 0.48627546, 0.4444831, 0.44288212, 0.42112032, 0.46950358, 0.5654028, 0.5325038] + ) + + # make sure that it's equal + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4 + class StableDiffusionXLMultiControlNetPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase From 950a75045658432d2834d6f76e5e4595b2c1dea1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Aug 2023 03:03:34 +0000 Subject: [PATCH 4/5] fix --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 0f7fdff19626..38b09cc17366 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -307,7 +307,6 @@ def test_controlnet_sdxl_guess(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) - # forward without prompt embeds inputs = self.get_dummy_inputs(torch_device) inputs["guess_mode"] = True From 82db3e63b00be444f876e905a350dc14880a334d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Aug 2023 09:00:16 +0000 Subject: [PATCH 5/5] move fast test to cpu --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 38b09cc17366..8fb76499dc14 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -301,19 +301,22 @@ def test_stable_diffusion_xl_prompt_embeds(self): assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 def test_controlnet_sdxl_guess(self): + device = "cpu" + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(torch_device) + inputs = self.get_dummy_inputs(device) inputs["guess_mode"] = True output = sd_pipe(**inputs) image_slice = output.images[0, -3:, -3:, -1] expected_slice = np.array( - [0.45368838, 0.38424692, 0.48627546, 0.4444831, 0.44288212, 0.42112032, 0.46950358, 0.5654028, 0.5325038] + [0.7330834, 0.590667, 0.5667336, 0.6029023, 0.5679491, 0.5968194, 0.4032986, 0.47612396, 0.5089609] ) # make sure that it's equal