284 changes: 225 additions & 59 deletions pipeline/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -1388,31 +1388,31 @@ def __call__(
)

# 1. Check inputs
for control_image in control_image_list:
if control_image:
self.check_inputs(
prompt,
prompt_2,
control_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
# for control_image in control_image_list:
# if control_image:
# self.check_inputs(
# prompt,
# prompt_2,
# control_image,
# mask_image,
# strength,
# num_inference_steps,
# callback_steps,
# output_type,
# negative_prompt,
# negative_prompt_2,
# prompt_embeds,
# negative_prompt_embeds,
# ip_adapter_image,
# ip_adapter_image_embeds,
# pooled_prompt_embeds,
# negative_pooled_prompt_embeds,
# controlnet_conditioning_scale,
# control_guidance_start,
# control_guidance_end,
# callback_on_step_end_tensor_inputs,
# padding_mask_crop,
# )
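# NOTE: check_inputs is left commented out above, presumably because it expects a single
# control image per control type and would reject the nested per-image lists handled below.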

self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1507,25 +1507,78 @@ def denoising_value_valid(dnv):
init_image = init_image.to(dtype=torch.float32)

# 5.2 Prepare control images
for idx in range(len(control_image_list)):
if control_image_list[idx]:
control_image = self.prepare_control_image(
image=control_image_list[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[idx] = control_image


# control_image_list may be a flat list of control-type images or a nested list (one sub-list per control image)
original_control_image_list = control_image_list
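# Keep a reference to the original (possibly nested) list: it is checked again inside the
# denoising loop to decide whether to run a separate ControlNet pass per control image.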

if isinstance(control_image_list, list) and not isinstance(control_image_list[0], list):
for idx in range(len(control_image_list)):
if control_image_list[idx]:
control_image = self.prepare_control_image(
image=control_image_list[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[idx] = control_image

elif isinstance(control_image_list, list) and isinstance(control_image_list[0], list):
for sub_idx, sub_control_image_list in enumerate(control_image_list):
for idx in range(len(sub_control_image_list)):
if sub_control_image_list[idx] != 0:
control_image = self.prepare_control_image(
image=sub_control_image_list[idx],
width=width,
height=height,
batch_size=1 * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
sub_control_image_list[idx] = control_image

# concatenate the control_image tensors at the same control-type position across the sub-lists of control_image_list
merged_control_image_list = [0] * len(control_image_list[0])
for idx in range(len(merged_control_image_list)):
tensors_to_concat = []
for sub_list in control_image_list:
val = sub_list[idx]
if isinstance(val, torch.Tensor):
tensors_to_concat.append(val)
if len(tensors_to_concat) > 0:
# After concat, we have [img1_uncond, img1_cond, img2_uncond, img2_cond, ...]
concatenated = torch.cat(tensors_to_concat, dim=0)

# Need to reorder to [img1_uncond, img2_uncond, ..., img1_cond, img2_cond, ...]
if self.do_classifier_free_guidance and len(tensors_to_concat) > 1:
num_images = len(tensors_to_concat)
# Each tensor is [uncond, cond], so total is num_images * 2
reordered = []
# Collect all uncond first
for img_idx in range(num_images):
reordered.append(concatenated[img_idx * 2:img_idx * 2 + 1])
# Then collect all cond
for img_idx in range(num_images):
reordered.append(concatenated[img_idx * 2 + 1:img_idx * 2 + 2])
merged_control_image_list[idx] = torch.cat(reordered, dim=0)
else:
merged_control_image_list[idx] = concatenated
else:
merged_control_image_list[idx] = 0
control_image_list = merged_control_image_list
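# At this point each entry of control_image_list is either 0 (unused control type) or a tensor
# with all unconditional halves first and all conditional halves second. Note that the reorder
# above assumes each per-image tensor has batch size 2 ([uncond, cond]), i.e. effectively
# num_images_per_prompt == 1.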

# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
@@ -1666,7 +1719,7 @@ def denoising_value_valid(dnv):
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
@@ -1701,20 +1754,133 @@ def denoising_value_valid(dnv):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

# # Resize control_image to match the size of the input to the controlnet
# if control_image.shape[-2:] != control_model_input.shape[-2:]:
# control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)

down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond_list=control_image_list,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
)
# Process multiple control images separately
if isinstance(original_control_image_list, list) and isinstance(original_control_image_list[0], list):
num_control_images = len(original_control_image_list)

down_block_res_samples_all = []
mid_block_res_sample_all = []

for idx in range(num_control_images):
# Extract the control images for this specific input
tmp_control_image_list = []
for img_list in control_image_list:
if img_list is not None and not isinstance(img_list, int):
# Each control type may hold a batch of images
# Split by number of images (not by CFG)
if self.do_classifier_free_guidance:
# Format is [img0_uncond, img1_uncond, img0_cond, img1_cond]
# We need to extract [img{idx}_uncond, img{idx}_cond]
batch_per_cfg = num_control_images
uncond_idx = idx
cond_idx = batch_per_cfg + idx
tmp_img = torch.cat([img_list[uncond_idx:uncond_idx+1], img_list[cond_idx:cond_idx+1]], dim=0)
else:
tmp_img = img_list[idx:idx+1]
tmp_control_image_list.append(tmp_img)
else:
tmp_control_image_list.append(img_list)

# Extract latents and embeddings for this specific image
if self.do_classifier_free_guidance:
# control_model_input format: [img0_uncond, img1_uncond, img0_cond, img1_cond]
batch_per_cfg = num_control_images
uncond_idx = idx
cond_idx = batch_per_cfg + idx
tmp_control_model_input = torch.cat([
control_model_input[uncond_idx:uncond_idx+1],
control_model_input[cond_idx:cond_idx+1]
], dim=0)
tmp_controlnet_prompt_embeds = torch.cat([
controlnet_prompt_embeds[uncond_idx:uncond_idx+1],
controlnet_prompt_embeds[cond_idx:cond_idx+1]
], dim=0)

# Extract corresponding text embeds and time ids
tmp_add_text_embeds = torch.cat([
add_text_embeds[uncond_idx:uncond_idx+1],
add_text_embeds[cond_idx:cond_idx+1]
], dim=0)
tmp_add_time_ids = torch.cat([
add_time_ids[uncond_idx:uncond_idx+1],
add_time_ids[cond_idx:cond_idx+1]
], dim=0)
tmp_control_type = torch.cat([
controlnet_added_cond_kwargs["control_type"][uncond_idx:uncond_idx+1],
controlnet_added_cond_kwargs["control_type"][cond_idx:cond_idx+1]
], dim=0)
else:
tmp_control_model_input = control_model_input[idx:idx+1]
tmp_controlnet_prompt_embeds = controlnet_prompt_embeds[idx:idx+1]
tmp_add_text_embeds = add_text_embeds[idx:idx+1]
tmp_add_time_ids = add_time_ids[idx:idx+1]
tmp_control_type = controlnet_added_cond_kwargs["control_type"][idx:idx+1]

tmp_controlnet_added_cond_kwargs = {
"text_embeds": tmp_add_text_embeds,
"time_ids": tmp_add_time_ids,
"control_type": tmp_control_type,
}

# Run controlnet for this specific image
tmp_down_block_res_samples, tmp_mid_block_res_sample = self.controlnet(
tmp_control_model_input,
t,
encoder_hidden_states=tmp_controlnet_prompt_embeds,
controlnet_cond_list=tmp_control_image_list,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=tmp_controlnet_added_cond_kwargs,
return_dict=False,
)

down_block_res_samples_all.append(tmp_down_block_res_samples)
mid_block_res_sample_all.append(tmp_mid_block_res_sample)

# Concatenate results maintaining the batch order
# Result should be [img0_uncond, img1_uncond, img0_cond, img1_cond] to match control_model_input
if self.do_classifier_free_guidance:
# Reorder: [img0_uncond, img0_cond, img1_uncond, img1_cond] -> [img0_uncond, img1_uncond, img0_cond, img1_cond]
down_block_res_samples = []
for layer_idx in range(len(down_block_res_samples_all[0])):
uncond_samples = []
cond_samples = []
for img_idx in range(num_control_images):
# Each sample is [uncond, cond] for one image
sample = down_block_res_samples_all[img_idx][layer_idx]
uncond_samples.append(sample[0:1])
cond_samples.append(sample[1:2])
# Concatenate: [all_uncond, all_cond]
down_block_res_samples.append(torch.cat(uncond_samples + cond_samples, dim=0))

# Same for mid block
uncond_mid_samples = []
cond_mid_samples = []
for img_idx in range(num_control_images):
sample = mid_block_res_sample_all[img_idx]
uncond_mid_samples.append(sample[0:1])
cond_mid_samples.append(sample[1:2])
mid_block_res_sample = torch.cat(uncond_mid_samples + cond_mid_samples, dim=0)
else:
# No CFG, simple concatenation
down_block_res_samples = [
torch.cat([samples[layer_idx] for samples in down_block_res_samples_all], dim=0)
for layer_idx in range(len(down_block_res_samples_all[0]))
]
mid_block_res_sample = torch.cat(mid_block_res_sample_all, dim=0)

else:
# Single control image or standard processing
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond_list=control_image_list,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
)

if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
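For reference, here is a minimal caller-side sketch of the nested `control_image_list` layout this diff introduces. It is an assumption-laden illustration, not part of the PR: the slot count, slot meanings (openpose/depth), and file names are hypothetical; only the list-of-lists shape and the use of `0` for unused control types are taken from the code above.

```python
# Hypothetical caller-side sketch (not part of this PR): build a nested control_image_list
# with one sub-list per control image and one slot per control type of the union ControlNet.
# Unused slots are 0, matching the `!= 0` checks in the preparation code above.
from PIL import Image

num_control_types = 6                     # assumed number of union control slots
openpose_img = Image.open("pose.png")     # illustrative inputs
depth_img = Image.open("depth.png")

control_image_list = [
    [0] * num_control_types,  # slots for the first control image
    [0] * num_control_types,  # slots for the second control image
]
control_image_list[0][0] = openpose_img   # slot index 0 assumed to be openpose
control_image_list[1][1] = depth_img      # slot index 1 assumed to be depth

# Passed to the pipeline's __call__, this nested list is merged per control type in
# section 5.2 and then split again into one ControlNet pass per control image in the
# denoising loop.
```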