From 0914fa3901b464488442b37887f69610361a7b3f Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 5 Feb 2025 07:51:38 +0000 Subject: [PATCH 1/5] ControlNet Union scale --- .../models/controlnets/controlnet_union.py | 19 +++++++-------- .../pipeline_controlnet_union_sd_xl.py | 24 +++++++++++++------ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 076e966f3d37..827522f2b054 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -605,7 +605,7 @@ def forward( controlnet_cond: List[torch.Tensor], control_type: torch.Tensor, control_type_idx: List[int], - conditioning_scale: float = 1.0, + conditioning_scale: Union[float, List[float]] = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -658,6 +658,9 @@ def forward( If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ + if isinstance(conditioning_scale, float): + conditioning_scale = [conditioning_scale] * len(controlnet_cond) + # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -742,12 +745,12 @@ def forward( inputs = [] condition_list = [] - for cond, control_idx in zip(controlnet_cond, control_type_idx): + for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale): condition = self.controlnet_cond_embedding(cond) feat_seq = torch.mean(condition, dim=(2, 3)) feat_seq = feat_seq + self.task_embedding[control_idx] - inputs.append(feat_seq.unsqueeze(1)) - condition_list.append(condition) + inputs.append(feat_seq.unsqueeze(1) * scale) + condition_list.append(condition * scale) condition = sample feat_seq = torch.mean(condition, dim=(2, 3)) @@ -759,10 +762,10 @@ def forward( x = layer(x) controlnet_cond_fuser = sample * 0.0 - for idx, condition in enumerate(condition_list[:-1]): + for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale): alpha = self.spatial_ch_projs(x[:, idx]) alpha = alpha.unsqueeze(-1).unsqueeze(-1) - controlnet_cond_fuser += condition + alpha + controlnet_cond_fuser += condition + alpha * scale sample = sample + controlnet_cond_fuser @@ -806,12 +809,8 @@ def forward( # 6. scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one - else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 27e627e5bac9..0acf93b31336 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -1132,20 +1132,29 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + if not isinstance(control_mode, list): + control_mode = [control_mode] + # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(control_mode) if not isinstance(control_image, list): control_image = [control_image] else: control_image = control_image.copy() - if not isinstance(control_mode, list): - control_mode = [control_mode] - if len(control_image) != len(control_mode): raise ValueError("Expected len(control_image) == len(control_type)") @@ -1278,10 +1287,11 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): - controlnet_keep.append( - 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) - ) + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) # 7.2 Prepare added time ids & embeddings original_size = original_size or (height, width) From 7160506adab68b998214259b1e390937a1d4846e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 13 Feb 2025 15:53:51 +0000 Subject: [PATCH 2/5] fix --- .../pipeline_controlnet_union_sd_xl.py | 32 ++++--------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index dfb62cc806d2..829a9cbb61b5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -757,15 +757,9 @@ def check_inputs( for images_ in image: for image_ in images_: self.check_image(image_, prompt, prompt_embeds) - else: - assert False # Check `controlnet_conditioning_scale` - # TODO Update for https://github.com/huggingface/diffusers/pull/10723 - if isinstance(controlnet, ControlNetUnionModel): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif isinstance(controlnet, MultiControlNetUnionModel): + if isinstance(controlnet, MultiControlNetUnionModel): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings is not supported at the moment.") @@ -776,8 +770,6 @@ def check_inputs( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) - else: - assert False if len(control_guidance_start) != len(control_guidance_end): raise ValueError( @@ -808,8 +800,6 @@ def check_inputs( for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): if max(_control_mode) >= _controlnet.config.num_control_type: raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") - else: - assert False # Equal number of `image` and `control_mode` elements if isinstance(controlnet, ControlNetUnionModel): @@ -823,8 +813,6 @@ def check_inputs( elif sum(len(x) for x in image) != sum(len(x) for x in control_mode): raise ValueError("Expected len(control_image) == len(control_mode)") - else: - assert False if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -1201,6 +1189,11 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + if not isinstance(control_image, list): + control_image = [control_image] + else: + control_image = control_image.copy() + if not isinstance(control_mode, list): control_mode = [control_mode] @@ -1216,15 +1209,7 @@ def __call__( mult * [control_guidance_end], ) - if not isinstance(control_image, list): - control_image = [control_image] - else: - control_image = control_image.copy() - - if not isinstance(control_mode, list): - control_mode = [control_mode] - - if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float): + if isinstance(controlnet_conditioning_scale, float): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult @@ -1361,9 +1346,6 @@ def __call__( control_image = control_images height, width = control_image[0][0].shape[-2:] - else: - assert False - # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas From a1f1c70ac4c937aad0433328c7223883d8eafb83 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 13 Feb 2025 16:21:05 +0000 Subject: [PATCH 3/5] universal interface --- .../pipelines/controlnet/pipeline_controlnet_union_sd_xl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 829a9cbb61b5..ca931c221eec 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -1197,6 +1197,10 @@ def __call__( if not isinstance(control_mode, list): control_mode = [control_mode] + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] + # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] From a54832dbd901c9d751b96886fe066e5e397cc71a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 13 Feb 2025 16:48:59 +0000 Subject: [PATCH 4/5] from_multi --- .../models/controlnets/controlnet_union.py | 21 ++++++++++++++++--- .../controlnets/multicontrolnet_union.py | 5 ++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 827522f2b054..26cb86718a21 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -611,6 +611,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + from_multi: bool = False, guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: @@ -647,6 +648,8 @@ def forward( Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + from_multi (`bool`, defaults to `False`): + Use standard scaling when called from `MultiControlNetUnionModel`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. @@ -749,8 +752,12 @@ def forward( condition = self.controlnet_cond_embedding(cond) feat_seq = torch.mean(condition, dim=(2, 3)) feat_seq = feat_seq + self.task_embedding[control_idx] - inputs.append(feat_seq.unsqueeze(1) * scale) - condition_list.append(condition * scale) + if from_multi: + inputs.append(feat_seq.unsqueeze(1)) + condition_list.append(condition) + else: + inputs.append(feat_seq.unsqueeze(1) * scale) + condition_list.append(condition * scale) condition = sample feat_seq = torch.mean(condition, dim=(2, 3)) @@ -765,7 +772,10 @@ def forward( for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale): alpha = self.spatial_ch_projs(x[:, idx]) alpha = alpha.unsqueeze(-1).unsqueeze(-1) - controlnet_cond_fuser += condition + alpha * scale + if from_multi: + controlnet_cond_fuser += condition + alpha + else: + controlnet_cond_fuser += condition + alpha * scale sample = sample + controlnet_cond_fuser @@ -809,8 +819,13 @@ def forward( # 6. scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + if from_multi: + scales = scales * conditioning_scale[0] down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + elif from_multi: + down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale[0] if self.config.global_pool_conditions: down_block_res_samples = [ diff --git a/src/diffusers/models/controlnets/multicontrolnet_union.py b/src/diffusers/models/controlnets/multicontrolnet_union.py index 6dbc0c97ff75..7ebcb33ea689 100644 --- a/src/diffusers/models/controlnets/multicontrolnet_union.py +++ b/src/diffusers/models/controlnets/multicontrolnet_union.py @@ -47,9 +47,12 @@ def forward( guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: + down_block_res_samples, mid_block_res_sample = None, None for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate( zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets) ): + if scale == 0.0: + continue down_samples, mid_sample = controlnet( sample=sample, timestep=timestep, @@ -68,7 +71,7 @@ def forward( ) # merge samples - if i == 0: + if down_block_res_samples is None and mid_block_res_sample is None: down_block_res_samples, mid_block_res_sample = down_samples, mid_sample else: down_block_res_samples = [ From ff515b12262ce4433a8f6f42e47b306f99245648 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 13 Feb 2025 16:52:54 +0000 Subject: [PATCH 5/5] from_multi --- src/diffusers/models/controlnets/multicontrolnet_union.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/controlnets/multicontrolnet_union.py b/src/diffusers/models/controlnets/multicontrolnet_union.py index 7ebcb33ea689..427e05b19110 100644 --- a/src/diffusers/models/controlnets/multicontrolnet_union.py +++ b/src/diffusers/models/controlnets/multicontrolnet_union.py @@ -66,6 +66,7 @@ def forward( attention_mask=attention_mask, added_cond_kwargs=added_cond_kwargs, cross_attention_kwargs=cross_attention_kwargs, + from_multi=True, guess_mode=guess_mode, return_dict=return_dict, )