37 changes: 2 additions & 35 deletions docs/source/en/optimization/fp16.md
@@ -65,42 +65,11 @@ image = pipe(prompt).images[0]

</Tip>

## Sliced attention for additional memory savings

For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.

<Tip>
Attention slicing is useful even with a batch size of just 1, as long
as the model uses more than one attention head. If there is more than one
attention head, the *QK^T* attention matrix can be computed sequentially for
each head, which can save a significant amount of memory.
</Tip>

To perform the attention computation sequentially over each head, you only need to invoke [`~DiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, as shown here:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]
```

There's a small performance penalty (about 10% slower inference), but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!


## Sliced VAE decode for larger batches

To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.

You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
You likely want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.

To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
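(A minimal sketch, assuming the same checkpoint as the earlier examples; the batch of 32 prompts is only illustrative.)

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# decode the batch latents one image at a time instead of all at once
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```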

@@ -126,7 +95,7 @@ You may see a small performance boost in VAE decode on multi-image batches.

Tiled VAE processing makes it possible to work with large images on limited VRAM, for example, generating 4K images with 8 GB of VRAM. The tiled VAE decoder splits the image into overlapping tiles, decodes the tiles, and blends the outputs to compose the final image.

You want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
You want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.

To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:
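(A minimal sketch, assuming the same checkpoint as above; the prompt and the 3840×2224 output size are only illustrative.)

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a beautiful landscape photograph"

# split the image into overlapping tiles, decode each tile, then blend the outputs
pipe.enable_vae_tiling()
image = pipe(prompt, width=3840, height=2224, num_inference_steps=10).images[0]
```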

@@ -192,7 +161,6 @@ pipe = StableDiffusionPipeline.from_pretrained(

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```
@@ -241,7 +209,6 @@ pipe = StableDiffusionPipeline.from_pretrained(

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```
29 changes: 27 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -1669,15 +1669,40 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor
in slices to compute attention in several steps. This is useful to save some memory in exchange for a small
speed decrease.
in slices to compute attention in several steps. For more than one attention head, the computation is performed
sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.

<Tip warning={true}>

⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slowdowns!

</Tip>

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.

Examples:

```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline

>>> pipe = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... use_safetensors=True,
... )

>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> pipe.enable_attention_slicing()
>>> image = pipe(prompt).images[0]
```
"""
self.set_attention_slice(slice_size)

Expand Down