StableDiffusionXLControlNetPipeline does not support Guess Mode #4709

@chuxing

Description

Describe the bug

When using StableDiffusionXLControlNetPipeline with guess_mode=True, the pipeline does not work and fails with the following error:
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [20, 64] but got: [10, 64].
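
The sizes in the error point to a batch mismatch inside cross-attention: with 10 attention heads, a leading dimension of 20 corresponds to the classifier-free-guidance batch of 2 images, while 10 corresponds to only the conditional half. Below is a minimal sketch of the failing torch.baddbmm call, with shapes chosen to mirror the traceback (they are assumptions, not values read out of the pipeline):

import torch

# Query built from a CFG-doubled latent batch: 2 images x 10 heads = 20.
query = torch.randn(20, 4096, 64)  # (batch * heads, latent tokens, head_dim)
# Key built from prompt embeddings covering only the conditional batch: 1 x 10 = 10.
key = torch.randn(10, 77, 64)      # (batch * heads, text tokens, head_dim)
bias = torch.zeros(20, 4096, 77)

# torch.baddbmm requires the leading batch dimensions of query and key to agree,
# so this raises a RuntimeError of the same form as the one reported above.
torch.baddbmm(bias, query, key.transpose(-1, -2))

In other words, in guess mode the ControlNet appears to receive latents for both the unconditional and conditional batches but prompt embeddings for only the conditional one, so the two sides of the attention no longer line up.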

Reproduction

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet_model = ControlNetModel.from_pretrained("")
sd_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "", controlnet=controlnet_model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

# prompt, negative_prompt and control_image are defined beforehand
image = sd_pipe(
    prompt=prompt, negative_prompt=negative_prompt, image=control_image,
    num_inference_steps=30, guess_mode=True,
).images[0]
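
As a point of comparison (assuming, as the title suggests, that the failure is specific to guess mode), the same call with guess mode left at its default runs through the pipeline and can serve as an interim workaround:

# Identical call with guess mode disabled (the default); prompt, negative_prompt
# and control_image are the same objects as above.
image = sd_pipe(
    prompt=prompt, negative_prompt=negative_prompt, image=control_image,
    num_inference_steps=30, guess_mode=False,
).images[0]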

Logs

File /opt/conda/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py:936, in StableDiffusionXLControlNetPipeline.__call__(self, prompt, prompt_2, image, height, width, num_inference_steps, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, controlnet_conditioning_scale, guess_mode, control_guidance_start, control_guidance_end, original_size, crops_coords_top_left, target_size)
    933     cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
    935 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
--> 936 down_block_res_samples, mid_block_res_sample = self.controlnet(
    937     control_model_input,
    938     t,
    939     encoder_hidden_states=controlnet_prompt_embeds,
    940     controlnet_cond=image,
    941     conditioning_scale=cond_scale,
    942     guess_mode=guess_mode,
    943     added_cond_kwargs=added_cond_kwargs,
    944     return_dict=False,
    945 )
    947 if guess_mode and do_classifier_free_guidance:
    948     # Infered ControlNet only for the conditional batch.
    949     # To apply the output of ControlNet to both the unconditional and conditional batches,
    950     # add 0 to the unconditional batch to keep it unchanged.
    951     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/diffusers/models/controlnet.py:760, in ControlNetModel.forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, class_labels, timestep_cond, attention_mask, added_cond_kwargs, cross_attention_kwargs, guess_mode, return_dict)
    758 for downsample_block in self.down_blocks:
    759     if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 760         sample, res_samples = downsample_block(
    761             hidden_states=sample,
    762             temb=emb,
    763             encoder_hidden_states=encoder_hidden_states,
    764             attention_mask=attention_mask,
    765             cross_attention_kwargs=cross_attention_kwargs,
    766         )
    767     else:
    768         sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py:996, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask, additional_residuals)
    994 else:
    995     hidden_states = resnet(hidden_states, temb)
--> 996     hidden_states = attn(
    997         hidden_states,
    998         encoder_hidden_states=encoder_hidden_states,
    999         cross_attention_kwargs=cross_attention_kwargs,
   1000         attention_mask=attention_mask,
   1001         encoder_attention_mask=encoder_attention_mask,
   1002         return_dict=False,
   1003     )[0]
   1005 # apply additional residuals to the output of the last pair of resnet and attention blocks
   1006 if i == len(blocks) - 1 and additional_residuals is not None:

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/diffusers/models/transformer_2d.py:292, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    290 # 2. Blocks
    291 for block in self.transformer_blocks:
--> 292     hidden_states = block(
    293         hidden_states,
    294         attention_mask=attention_mask,
    295         encoder_hidden_states=encoder_hidden_states,
    296         encoder_attention_mask=encoder_attention_mask,
    297         timestep=timestep,
    298         cross_attention_kwargs=cross_attention_kwargs,
    299         class_labels=class_labels,
    300     )
    302 # 3. Output
    303 if self.is_input_continuous:

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/diffusers/models/attention.py:171, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
    166 if self.attn2 is not None:
    167     norm_hidden_states = (
    168         self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
    169     )
--> 171     attn_output = self.attn2(
    172         norm_hidden_states,
    173         encoder_hidden_states=encoder_hidden_states,
    174         attention_mask=encoder_attention_mask,
    175         **cross_attention_kwargs,
    176     )
    177     hidden_states = attn_output + hidden_states
    179 # 3. Feed-forward

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/diffusers/models/attention_processor.py:322, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    318 def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    319     # The `Attention` class can call different attention processors / attention functions
    320     # here we simply pass along all tensors to the selected processor class
    321     # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 322     return self.processor(
    323         self,
    324         hidden_states,
    325         encoder_hidden_states=encoder_hidden_states,
    326         attention_mask=attention_mask,
    327         **cross_attention_kwargs,
    328     )

File ~/.local/lib/python3.9/site-packages/diffusers/models/attention_processor.py:489, in AttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
    486 key = attn.head_to_batch_dim(key)
    487 value = attn.head_to_batch_dim(value)
--> 489 attention_probs = attn.get_attention_scores(query, key, attention_mask)
    490 hidden_states = torch.bmm(attention_probs, value)
    491 hidden_states = attn.batch_to_head_dim(hidden_states)

File ~/.local/lib/python3.9/site-packages/diffusers/models/attention_processor.py:363, in Attention.get_attention_scores(self, query, key, attention_mask)
    360     baddbmm_input = attention_mask
    361     beta = 1
--> 363 attention_scores = torch.baddbmm(
    364     baddbmm_input,
    365     query,
    366     key.transpose(-1, -2),
    367     beta=beta,
    368     alpha=self.scale,
    369 )
    370 del baddbmm_input
    372 if self.upcast_softmax:

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [20, 64] but got: [10, 64].

System Info

Name: diffusers
Version: 0.19.1
Summary: Diffusers
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: patrick@huggingface.co
License: Apache
Location: /home/jovyan/.local/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, importlib-metadata, numpy, Pillow, regex, requests, safetensors

Who can help?

No response

Labels

bug (Something isn't working)
