StableDiffusion3 pipeline RuntimeError when using prompt_embeds #10712

@pjjajal

Describe the bug

The StableDiffusion3 pipeline throws a RuntimeError when prompt_embeds is passed in lieu of prompt and num_images_per_prompt > 1.

I am attempting to generate images using the StableDiffusion3 pipeline with precomputed prompt embeddings. The prompt embeddings are computed with the pipeline's .encode_prompt(...) method and then passed to the pipeline's __call__(...). Passing these embeddings to the pipeline raises a RuntimeError in either of the following cases:

  • num_images_per_prompt > 1 for both .encode_prompt(...) and __call__(...).
  • num_images_per_prompt = 1 for .encode_prompt(...) and num_images_per_prompt > 1 for __call__(...).
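
The batch-size arithmetic behind the two cases can be sketched in plain Python. This is only a model under stated assumptions, not the pipeline's actual code: it assumes that encode_prompt returns one embedding per prompt per requested image, that __call__ infers its batch size from the first dimension of prompt_embeds, that latents are allocated as that batch size times num_images_per_prompt, and that the supplied embeddings are not repeated again. The helper name sd3_batch_sizes is hypothetical. For the first case the model gives 8 and 4, which agrees with the sizes in the RuntimeError under Logs.

```python
def sd3_batch_sizes(num_prompts, encode_n, call_n, cfg=True):
    """Return (latent_batch, pooled_batch) under the assumptions in the lead-in.

    Hypothetical helper for illustration; it does not call diffusers.
    """
    # ASSUMPTION: encode_prompt yields encode_n embeddings per prompt.
    embeds_batch = num_prompts * encode_n
    # ASSUMPTION: __call__ infers its batch size from the provided embeddings.
    inferred_batch = embeds_batch
    # ASSUMPTION: latents are allocated for call_n images per inferred batch item.
    latent_batch = inferred_batch * call_n
    # ASSUMPTION: the provided embeddings are used as-is, without repeating.
    pooled_batch = embeds_batch
    if cfg:
        # Classifier-free guidance concatenates negative and positive inputs,
        # doubling both the latent batch and the embedding batch.
        latent_batch *= 2
        pooled_batch *= 2
    return latent_batch, pooled_batch


print(sd3_batch_sizes(1, encode_n=2, call_n=2))  # (8, 4)
print(sd3_batch_sizes(1, encode_n=1, call_n=2))  # (4, 2)
print(sd3_batch_sizes(1, encode_n=1, call_n=1))  # (2, 2)
```

In this model the two sizes only agree when the call uses num_images_per_prompt=1, which is consistent with both reported failures but has not been checked against the pipeline source.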

The StableDiffusionXL pipeline does not raise an error in either of these cases.

Reproduction

StableDiffusion3 Failing Cases

encode_prompt num_images_per_prompt>1 and call num_images_per_prompt>1

The code for this failing case is below:

import torch
from diffusers import DiffusionPipeline

model_name = "stabilityai/stable-diffusion-3.5-medium"
pipe = DiffusionPipeline.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")

# encode the prompts
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A painting of a cat",
    prompt_2=None,
    prompt_3=None,
    device="cuda",
    do_classifier_free_guidance=True,
    num_images_per_prompt=2, # NOTE
)

# sample (generate) from the diffusion model.
with torch.inference_mode():
    out = pipe(
        height=64, # this is set small for speeding up testing
        width=64, # this is set small for speeding up testing
        num_images_per_prompt=2,  # NOTE
        num_inference_steps=1,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        generator=torch.Generator(0)
    )

encode_prompt num_images_per_prompt=1 and call num_images_per_prompt>1

import torch
from diffusers import DiffusionPipeline

model_name = "stabilityai/stable-diffusion-3.5-medium"
pipe = DiffusionPipeline.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")

# encode the prompts
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A painting of a cat",
    prompt_2=None,
    prompt_3=None,
    device="cuda",
    do_classifier_free_guidance=True,
    num_images_per_prompt=1, # NOTE
)

# sample (generate) from the diffusion model.
with torch.inference_mode():
    out = pipe(
        height=64, # this is set small for speeding up testing
        width=64, # this is set small for speeding up testing
        num_images_per_prompt=2,  # NOTE
        num_inference_steps=1,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        generator=torch.Generator(0)
    )

Expected Behaviour (StableDiffusionXL pipeline)

encode_prompt num_images_per_prompt=1 and call num_images_per_prompt>1

import torch
from diffusers import DiffusionPipeline

model_name = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A painting of a cat",
    device="cuda",
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
)

with torch.inference_mode():
    out = pipe(
        # prompt="Cat",
        height=64,
        width=64,
        num_images_per_prompt=2,
        num_inference_steps=1,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        generator=torch.Generator(0)
    )
print(len(out.images)) # this returns 2.

encode_prompt num_images_per_prompt>1 and call num_images_per_prompt>1

import torch
from diffusers import DiffusionPipeline

model_name = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A painting of a cat",
    device="cuda",
    do_classifier_free_guidance=True,
    num_images_per_prompt=2,
)

with torch.inference_mode():
    out = pipe(
        # prompt="Cat",
        height=64,
        width=64,
        num_images_per_prompt=2,
        num_inference_steps=1,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        generator=torch.Generator(0)
    )
print(len(out.images)) # this returns 4.

Logs

Traceback (most recent call last):
  File "/home/jajal/research/diffusion-trajectory/sd3.py", line 26, in <module>
    out = pipe(
          ^^^^^
  File "/home/jajal/mambaforge/envs/diff-traf/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/research/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 1060, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/home/jajal/mambaforge/envs/diff-traf/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/mambaforge/envs/diff-traf/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/research/diffusers/src/diffusers/models/transformers/transformer_sd3.py", line 389, in forward
    temb = self.time_text_embed(timestep, pooled_projections)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/mambaforge/envs/diff-traf/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/mambaforge/envs/diff-traf/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jajal/research/diffusers/src/diffusers/models/embeddings.py", line 1606, in forward
    conditioning = timesteps_emb + pooled_projections
                   ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 0
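
For readers unfamiliar with the wording "non-singleton dimension", the following is a minimal pure-Python sketch of the broadcasting rule that NumPy and PyTorch share: shapes are aligned from the right, and two sizes are compatible only if they are equal or one of them is 1. The function broadcast_shape is a hypothetical helper written for illustration, and the feature size 16 is an arbitrary placeholder, not the real width of pooled_projections.

```python
def broadcast_shape(a, b):
    """Return the broadcast of shapes a and b, or raise ValueError.

    Hypothetical helper mirroring the NumPy/PyTorch broadcasting rule.
    """
    # Pad the shorter shape with leading 1s so both have the same rank.
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            # Neither size is 1 and they differ: this is the failing case.
            raise ValueError(
                f"size {x} must match size {y} at non-singleton dimension {dim}"
            )
    return tuple(result)


# A batch of 1 broadcasts against a batch of 8.
print(broadcast_shape((8, 16), (1, 16)))  # (8, 16)

# A batch of 4 does not, matching the sizes reported in the traceback.
try:
    broadcast_shape((8, 16), (4, 16))
except ValueError as err:
    print(err)  # size 8 must match size 4 at non-singleton dimension 0
```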

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.8
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.2
  • Accelerate version: 1.3.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.2
  • xFormers version: 0.0.29.post2
  • Accelerator: NVIDIA GeForce RTX 4070 Ti, 12282 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@yiyixuxu @sayakpaul

Labels

bug