34 changes: 21 additions & 13 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1057,32 +1057,41 @@ def __call__(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

use_controlnet_conditional_branch_only = guess_mode and do_classifier_free_guidance
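# in guess mode with CFG, the controlnet runs only on the conditional batch,
# so its inputs are captured here before the unet inputs are expanded for CFG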
if use_controlnet_conditional_branch_only:
controlnet_prompt_embeds = prompt_embeds
controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

if do_classifier_free_guidance:
# expand all inputs if we are doing classifier free guidance
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

if not use_controlnet_conditional_branch_only:
# controlnet inputs are the same as the inputs of the base unet model
controlnet_prompt_embeds = prompt_embeds
controlnet_added_cond_kwargs = added_cond_kwargs

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
scaled_latents = self.scheduler.scale_model_input(latents, t)
latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents
Comment on lines +1087 to +1088

Member:

Could I have an explanation of why this differs from

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

?

Author:

Just wanted to make the code around switching CFG clearer. I don't see the point in scaling the latents once for the main unet and then a second time for the controlnet, so I made a change that is aligned with all the other inputs (prompt_embeds, text_embeds, time_ids): we first prepare the standard inputs for the controlnet, and then expand them (torch.cat([...] * 2)) for the unet if CFG requires it. A minimal sketch of this flow follows below.
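
A minimal, hypothetical sketch of the prepare-then-expand flow described above; the names prepare_embeds, cond_embeds, neg_embeds, and do_cfg are illustrative, not the pipeline's actual variables:

import torch

def prepare_embeds(cond_embeds: torch.Tensor, neg_embeds: torch.Tensor, do_cfg: bool):
    # 1) prepare the standard (conditional-only) inputs; the controlnet can
    #    consume these directly when it only runs on the conditional batch
    controlnet_embeds = cond_embeds
    # 2) expand for the unet only when classifier-free guidance needs both
    #    the negative and the conditional branch
    unet_embeds = torch.cat([neg_embeds, cond_embeds], dim=0) if do_cfg else cond_embeds
    return controlnet_embeds, unet_embeds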

Member:

Got it. Clean.

Collaborator (@yiyixuxu), Aug 23, 2023:

I just took a closer look. Is it possible to do exactly the same as here? https://github.com/huggingface/diffusers/blob/80871ac5971fe7e708befa3b553463c4e61b22ab/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L938C19-L938C19

I think the logic is very clear there:

  1. all the code that creates inputs for the controlnet model is contained within the if ... else ... statement (see the sketch below)
  2. it is very clear that the controlnet input differs from our regular unet model input only when both guess_mode and do_classifier_free_guidance are set
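
A rough sketch of the pattern being referenced, paraphrased from the linked pipeline_controlnet.py and wrapped in a hypothetical helper for illustration:

def build_controlnet_inputs(latents, latent_model_input, prompt_embeds, scheduler, t,
                            guess_mode, do_classifier_free_guidance):
    # all controlnet-input construction lives in this single if ... else ...
    if guess_mode and do_classifier_free_guidance:
        # infer ControlNet only for the conditional batch
        control_model_input = scheduler.scale_model_input(latents, t)
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds
    return control_model_input, controlnet_prompt_embeds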


# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
if use_controlnet_conditional_branch_only:
control_model_input = scaled_latents
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1092,15 +1101,14 @@ def __call__(
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=added_cond_kwargs,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
)

18 changes: 17 additions & 1 deletion tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -14,6 +14,8 @@
# limitations under the License.

import unittest
import itertools
from parameterized import parameterized_class

import numpy as np
import torch
@@ -47,6 +49,10 @@
enable_full_determinism()


@parameterized_class(
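# run the fast tests for every (guess_mode, guidance_scale) combination;
# guidance_scale 1.0 disables classifier-free guidance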
("guess_mode", "guidance_scale"),
list(itertools.product([False, True], [1.0, 6.0]))
)
class StableDiffusionXLControlNetPipelineFastTests(
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
@@ -56,6 +62,12 @@ class StableDiffusionXLControlNetPipelineFastTests(
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

@classmethod
def setUpClass(cls):
if cls == StableDiffusionXLControlNetPipelineFastTests:
raise unittest.SkipTest("`parameterized_class` bug, see https://github.com/wolever/parameterized/issues/119")
super().setUpClass()

def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -158,7 +170,8 @@ def get_dummy_inputs(self, device, seed=0):
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"guidance_scale": self.guidance_scale,
"guess_mode": self.guess_mode,
"output_type": "numpy",
"image": image,
}
@@ -234,6 +247,9 @@ def test_stable_diffusion_xl_multi_prompts(self):
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

if self.guidance_scale <= 1: # negative prompt has no effect without CFG
return

# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"