diff --git a/pipeline/pipeline_controlnet_union_inpaint_sd_xl.py b/pipeline/pipeline_controlnet_union_inpaint_sd_xl.py
index b0028aa..808d65e 100644
--- a/pipeline/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/pipeline/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -1388,31 +1388,31 @@ def __call__(
         )
 
         # 1. Check inputs
-        for control_image in control_image_list:
-            if control_image:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    control_image,
-                    mask_image,
-                    strength,
-                    num_inference_steps,
-                    callback_steps,
-                    output_type,
-                    negative_prompt,
-                    negative_prompt_2,
-                    prompt_embeds,
-                    negative_prompt_embeds,
-                    ip_adapter_image,
-                    ip_adapter_image_embeds,
-                    pooled_prompt_embeds,
-                    negative_pooled_prompt_embeds,
-                    controlnet_conditioning_scale,
-                    control_guidance_start,
-                    control_guidance_end,
-                    callback_on_step_end_tensor_inputs,
-                    padding_mask_crop,
-                )
+        # for control_image in control_image_list:
+        #     if control_image:
+        #         self.check_inputs(
+        #             prompt,
+        #             prompt_2,
+        #             control_image,
+        #             mask_image,
+        #             strength,
+        #             num_inference_steps,
+        #             callback_steps,
+        #             output_type,
+        #             negative_prompt,
+        #             negative_prompt_2,
+        #             prompt_embeds,
+        #             negative_prompt_embeds,
+        #             ip_adapter_image,
+        #             ip_adapter_image_embeds,
+        #             pooled_prompt_embeds,
+        #             negative_pooled_prompt_embeds,
+        #             controlnet_conditioning_scale,
+        #             control_guidance_start,
+        #             control_guidance_end,
+        #             callback_on_step_end_tensor_inputs,
+        #             padding_mask_crop,
+        #         )
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
@@ -1507,25 +1507,78 @@ def denoising_value_valid(dnv):
             init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
-        for idx in range(len(control_image_list)):
-            if control_image_list[idx]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[idx],
-                    width=width,
-                    height=height,
-                    batch_size=batch_size * num_images_per_prompt,
-                    num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    crops_coords=crops_coords,
-                    resize_mode=resize_mode,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    guess_mode=guess_mode,
-                )
-                height, width = control_image.shape[-2:]
-                control_image_list[idx] = control_image
-
-
+        # if control_image_list is a flat list of images
+        original_control_image_list = control_image_list
+
+        if isinstance(control_image_list, list) and not isinstance(control_image_list[0], list):
+            for idx in range(len(control_image_list)):
+                if control_image_list[idx]:
+                    control_image = self.prepare_control_image(
+                        image=control_image_list[idx],
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        crops_coords=crops_coords,
+                        resize_mode=resize_mode,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+                    height, width = control_image.shape[-2:]
+                    control_image_list[idx] = control_image
+
+        elif isinstance(control_image_list, list) and isinstance(control_image_list[0], list):
+            for sub_idx, sub_control_image_list in enumerate(control_image_list):
+                for idx in range(len(sub_control_image_list)):
+                    if sub_control_image_list[idx] != 0:
+                        control_image = self.prepare_control_image(
+                            image=sub_control_image_list[idx],
+                            width=width,
+                            height=height,
+                            batch_size=1 * num_images_per_prompt,
+                            num_images_per_prompt=num_images_per_prompt,
+                            device=device,
+                            dtype=controlnet.dtype,
+                            crops_coords=crops_coords,
+                            resize_mode=resize_mode,
+                            do_classifier_free_guidance=self.do_classifier_free_guidance,
+                            guess_mode=guess_mode,
+                        )
+                        height, width = control_image.shape[-2:]
+                        sub_control_image_list[idx] = control_image
+
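+            # Each prepared entry is now either 0 (unused control type) or a tensor whose batch
+            # dimension stacks the unconditional and conditional copies for that image when
+            # classifier-free guidance is enabled (and guess_mode is off); the reordering below
+            # assumes this [uncond, cond] layout.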
+            # Concatenate the control image tensors at the same control-type index across the sub-lists in control_image_list
+            merged_control_image_list = [0] * len(control_image_list[0])
+            for idx in range(len(merged_control_image_list)):
+                tensors_to_concat = []
+                for sub_list in control_image_list:
+                    val = sub_list[idx]
+                    if isinstance(val, torch.Tensor):
+                        tensors_to_concat.append(val)
+                if len(tensors_to_concat) > 0:
+                    # After concat, we have [img1_uncond, img1_cond, img2_uncond, img2_cond, ...]
+                    concatenated = torch.cat(tensors_to_concat, dim=0)
+
+                    # Need to reorder to [img1_uncond, img2_uncond, ..., img1_cond, img2_cond, ...]
+                    if self.do_classifier_free_guidance and len(tensors_to_concat) > 1:
+                        num_images = len(tensors_to_concat)
+                        # Each tensor is [uncond, cond], so total is num_images * 2
+                        reordered = []
+                        # Collect all uncond first
+                        for img_idx in range(num_images):
+                            reordered.append(concatenated[img_idx * 2:img_idx * 2 + 1])
+                        # Then collect all cond
+                        for img_idx in range(num_images):
+                            reordered.append(concatenated[img_idx * 2 + 1:img_idx * 2 + 2])
+                        merged_control_image_list[idx] = torch.cat(reordered, dim=0)
+                    else:
+                        merged_control_image_list[idx] = concatenated
+                else:
+                    merged_control_image_list[idx] = 0
+            control_image_list = merged_control_image_list
+
         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
             mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
@@ -1666,7 +1719,7 @@ def denoising_value_valid(dnv):
                 )
                 num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
                 timesteps = timesteps[:num_inference_steps]
-
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -1701,20 +1754,133 @@ def denoising_value_valid(dnv):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
-                # # Resize control_image to match the size of the input to the controlnet
-                # if control_image.shape[-2:] != control_model_input.shape[-2:]:
-                #     control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)
-
-
-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    control_model_input,
-                    t,
-                    encoder_hidden_states=controlnet_prompt_embeds,
-                    controlnet_cond_list=control_image_list,
-                    conditioning_scale=cond_scale,
-                    guess_mode=guess_mode,
-                    added_cond_kwargs=controlnet_added_cond_kwargs,
-                    return_dict=False,
-                )
+                # Process multiple control images separately
+                if isinstance(original_control_image_list, list) and isinstance(original_control_image_list[0], list):
+                    num_control_images = len(original_control_image_list)
+
+                    down_block_res_samples_all = []
+                    mid_block_res_sample_all = []
+
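+                    # Run one ControlNet pass per control image so each image is conditioned only
+                    # on its own control inputs; the per-image residuals are merged back into CFG
+                    # batch order after the loop.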
+                    for idx in range(num_control_images):
+                        # Extract the control images for this specific input
+                        tmp_control_image_list = []
+                        for img_list in control_image_list:
+                            if img_list is not None and not isinstance(img_list, int):
+                                # Each control type might have a batch of images
+                                # Split by number of images (not by CFG)
+                                if self.do_classifier_free_guidance:
+                                    # Format is [img0_uncond, img1_uncond, img0_cond, img1_cond]
+                                    # We need to extract [img{idx}_uncond, img{idx}_cond]
+                                    batch_per_cfg = num_control_images
+                                    uncond_idx = idx
+                                    cond_idx = batch_per_cfg + idx
+                                    tmp_img = torch.cat([img_list[uncond_idx:uncond_idx+1], img_list[cond_idx:cond_idx+1]], dim=0)
+                                else:
+                                    tmp_img = img_list[idx:idx+1]
+                                tmp_control_image_list.append(tmp_img)
+                            else:
+                                tmp_control_image_list.append(img_list)
+
+                        # Extract latents and embeddings for this specific image
+                        if self.do_classifier_free_guidance:
+                            # control_model_input format: [img0_uncond, img1_uncond, img0_cond, img1_cond]
+                            batch_per_cfg = num_control_images
+                            uncond_idx = idx
+                            cond_idx = batch_per_cfg + idx
+                            tmp_control_model_input = torch.cat([
+                                control_model_input[uncond_idx:uncond_idx+1],
+                                control_model_input[cond_idx:cond_idx+1]
+                            ], dim=0)
+                            tmp_controlnet_prompt_embeds = torch.cat([
+                                controlnet_prompt_embeds[uncond_idx:uncond_idx+1],
+                                controlnet_prompt_embeds[cond_idx:cond_idx+1]
+                            ], dim=0)
+
+                            # Extract corresponding text embeds and time ids
+                            tmp_add_text_embeds = torch.cat([
+                                add_text_embeds[uncond_idx:uncond_idx+1],
+                                add_text_embeds[cond_idx:cond_idx+1]
+                            ], dim=0)
+                            tmp_add_time_ids = torch.cat([
+                                add_time_ids[uncond_idx:uncond_idx+1],
+                                add_time_ids[cond_idx:cond_idx+1]
+                            ], dim=0)
+                            tmp_control_type = torch.cat([
+                                controlnet_added_cond_kwargs["control_type"][uncond_idx:uncond_idx+1],
+                                controlnet_added_cond_kwargs["control_type"][cond_idx:cond_idx+1]
+                            ], dim=0)
+                        else:
+                            tmp_control_model_input = control_model_input[idx:idx+1]
+                            tmp_controlnet_prompt_embeds = controlnet_prompt_embeds[idx:idx+1]
+                            tmp_add_text_embeds = add_text_embeds[idx:idx+1]
+                            tmp_add_time_ids = add_time_ids[idx:idx+1]
+                            tmp_control_type = controlnet_added_cond_kwargs["control_type"][idx:idx+1]
+
+                        tmp_controlnet_added_cond_kwargs = {
+                            "text_embeds": tmp_add_text_embeds,
+                            "time_ids": tmp_add_time_ids,
+                            "control_type": tmp_control_type,
+                        }
+
+                        # Run the ControlNet for this specific image
+                        tmp_down_block_res_samples, tmp_mid_block_res_sample = self.controlnet(
+                            tmp_control_model_input,
+                            t,
+                            encoder_hidden_states=tmp_controlnet_prompt_embeds,
+                            controlnet_cond_list=tmp_control_image_list,
+                            conditioning_scale=cond_scale,
+                            guess_mode=guess_mode,
+                            added_cond_kwargs=tmp_controlnet_added_cond_kwargs,
+                            return_dict=False,
+                        )
+
+                        down_block_res_samples_all.append(tmp_down_block_res_samples)
+                        mid_block_res_sample_all.append(tmp_mid_block_res_sample)
+
+                    # Concatenate results maintaining the batch order
+                    # Result should be [img0_uncond, img1_uncond, img0_cond, img1_cond] to match control_model_input
+                    if self.do_classifier_free_guidance:
+                        # Reorder: [img0_uncond, img0_cond, img1_uncond, img1_cond] -> [img0_uncond, img1_uncond, img0_cond, img1_cond]
+                        down_block_res_samples = []
+                        for layer_idx in range(len(down_block_res_samples_all[0])):
+                            uncond_samples = []
+                            cond_samples = []
+                            for img_idx in range(num_control_images):
+                                # Each sample is [uncond, cond] for one image
+                                sample = down_block_res_samples_all[img_idx][layer_idx]
+                                uncond_samples.append(sample[0:1])
+                                cond_samples.append(sample[1:2])
+                            # Concatenate: [all_uncond, all_cond]
+                            down_block_res_samples.append(torch.cat(uncond_samples + cond_samples, dim=0))
+
+                        # Same for mid block
+                        uncond_mid_samples = []
+                        cond_mid_samples = []
+                        for img_idx in range(num_control_images):
+                            sample = mid_block_res_sample_all[img_idx]
+                            uncond_mid_samples.append(sample[0:1])
+                            cond_mid_samples.append(sample[1:2])
+                        mid_block_res_sample = torch.cat(uncond_mid_samples + cond_mid_samples, dim=0)
+                    else:
+                        # No CFG, simple concatenation
+                        down_block_res_samples = [
+                            torch.cat([samples[layer_idx] for samples in down_block_res_samples_all], dim=0)
+                            for layer_idx in range(len(down_block_res_samples_all[0]))
+                        ]
+                        mid_block_res_sample = torch.cat(mid_block_res_sample_all, dim=0)
+
+                else:
+                    # Single control image or standard processing
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        control_model_input,
+                        t,
+                        encoder_hidden_states=controlnet_prompt_embeds,
+                        controlnet_cond_list=control_image_list,
+                        conditioning_scale=cond_scale,
+                        guess_mode=guess_mode,
+                        added_cond_kwargs=controlnet_added_cond_kwargs,
+                        return_dict=False,
+                    )
 
                 if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
diff --git a/promax/controlnet_union_test_inpainting.py b/promax/controlnet_union_test_inpainting.py
index 81befdc..ed12716 100644
--- a/promax/controlnet_union_test_inpainting.py
+++ b/promax/controlnet_union_test_inpainting.py
@@ -1,4 +1,5 @@
 # diffusers ControlNet test
+import time
 import os
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import sys
@@ -78,17 +79,19 @@ def HWC3(x):
 mask_gen = get_mask_generator(kind='mixed', kwargs=mask_gen_kwargs)
 
-prompt = "your prompt, the longer the better, you can describe it as detail as possible"
+prompt = "a cat and a dog playing football on the field, high quality, detailed painting, artstation"
 negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
 
 seed = random.randint(0, 2147483647)
 
 # The original image you want to repaint.
-original_img = cv2.imread("your image path")
+original_img = cv2.imread("your image path 1")
+original_img_2 = cv2.imread("your image path 2")
+
 # # inpainting support any mask shape
 # # where you want to repaint, the mask image should be a binary image, with value 0 or 255.
-# mask = cv2.imread("your mask image path")
+# mask = cv2.imread("your mask image path")
 
 height, width, _ = original_img.shape
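+# scale factor used to resize the image to an area of roughly 1024x1024 while keeping its aspect ratio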
 ratio = np.sqrt(1024. * 1024. / (width * height))
@@ -96,18 +99,33 @@ def HWC3(x):
 original_img = cv2.resize(original_img, (W, H))
 original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
+original_img_2 = cv2.resize(original_img_2, (W, H))
+original_img_2 = cv2.cvtColor(original_img_2, cv2.COLOR_BGR2RGB)
+
 import copy
 controlnet_img = copy.deepcopy(original_img)
 controlnet_img = np.transpose(controlnet_img, (2, 0, 1))
+
+controlnet_img_2 = copy.deepcopy(original_img_2)
+controlnet_img_2 = np.transpose(controlnet_img_2, (2, 0, 1))
+
 mask = mask_gen(controlnet_img)
-controlnet_img = np.transpose(controlnet_img, (1, 2, 0))
 mask = np.transpose(mask, (1, 2, 0))
+controlnet_img = np.transpose(controlnet_img, (1, 2, 0))
+controlnet_img_2 = np.transpose(controlnet_img_2, (1, 2, 0))
+
 controlnet_img[mask.squeeze() > 0.0] = 0
+controlnet_img_2[mask.squeeze() > 0.0] = 0
+
 mask = HWC3((mask * 255).astype('uint8'))
 
 controlnet_img = Image.fromarray(controlnet_img)
+controlnet_img_2 = Image.fromarray(controlnet_img_2)
+
 original_img = Image.fromarray(original_img)
+original_img_2 = Image.fromarray(original_img_2)
+
 mask = Image.fromarray(mask)
 
 width, height = W, H
@@ -120,19 +138,53 @@ def HWC3(x):
 # 5 -- segment
 # 6 -- tile
 # 7 -- repaint
-images = pipe(prompt=[prompt]*1,
+
+control_image_list = [
+    [0, 0, 0, 0, 0, 0, 0, controlnet_img],
+    [0, 0, 0, 0, 0, 0, 0, controlnet_img_2],
+]
+
+
+generator = torch.Generator('cuda').manual_seed(seed)
+start_time = time.time()
+images = pipe(
+    prompt=[prompt]*1,
     image=original_img,
     mask_image=mask,
    control_image_list=[0, 0, 0, 0, 0, 0, 0, controlnet_img],
    negative_prompt=[negative_prompt]*1,
-    # generator=generator,
+    generator=generator,
+    width=width,
+    height=height,
+    num_inference_steps=12,
+    union_control=True,
+    union_control_type=torch.Tensor([0, 0, 0, 0, 0, 0, 0, 1]),
+    ).images
+end_time = time.time()
+print(f"Single controlnet image inference time: {end_time - start_time} seconds")
+for i in range(len(images)):
+    images[i].save(f"output_single_control_{i}.png")
+
+generator = torch.Generator('cuda').manual_seed(seed)
+start_time = time.time()
+images = pipe(
+    prompt=[prompt]*2,
+    image=[original_img, original_img_2],
+    mask_image=mask,
+    control_image_list=control_image_list,
+    negative_prompt=[negative_prompt]*2,
+    generator=generator,
     width=width,
     height=height,
-    num_inference_steps=30,
+    num_inference_steps=12,
     union_control=True,
     union_control_type=torch.Tensor([0, 0, 0, 0, 0, 0, 0, 1]),
+    guidance_scale=12.0,
     ).images
+end_time = time.time()
+print(f"Batch processing ControlNet Union inference time: {end_time - start_time} seconds")
+
+for i in range(len(images)):
+    images[i].save(f"output_{i}.png")
 
-controlnet_img.save("control_inpainting.webp")
-images[0].save(f"your image save path, png format is usually better than jpg or webp in terms of image quality but got much bigger")