
Conversation

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Sep 5, 2022

100% of the credit goes to the amazing thread here: basujindal/stable-diffusion#117

By doing:

pipe.set_attention_chunk()

VRAM usage can be reduced from 4.6 GB to 3.4 GB at the cost of only ~10% slower inference. Try it out:

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    revision="fp16", 
    torch_dtype=torch.float16,
    use_auth_token=True
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.set_attention_chunk()
with autocast("cuda"):
    image = pipe(prompt).images[0] 
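
For intuition on what the chunking buys: instead of materializing the full (batch * heads, seq_len, seq_len) attention matrix at once, the attention is computed for a slice of the batch * heads rows at a time, so only a fraction of that matrix lives in memory at any point. A minimal, self-contained sketch of the idea (not the actual code of this PR; sliced_attention is just an illustrative name):

import torch

def sliced_attention(q, k, v, slice_size):
    # q, k, v: (batch * heads, seq_len, head_dim), as produced by
    # reshape_heads_to_batch_dim in the diffusers attention module
    scale = q.shape[-1] ** -0.5
    out = torch.empty(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype, device=q.device)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        # only a (slice_size, seq_len, seq_len) attention matrix exists at any one time
        attn = (torch.bmm(q[i:end], k[i:end].transpose(1, 2)) * scale).softmax(dim=-1)
        out[i:end] = torch.bmm(attn, v[i:end])
    return out

With slice_size equal to q.shape[0] this is ordinary attention; smaller values trade a little speed for a much smaller peak memory footprint.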

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 5, 2022

The documentation is not available anymore as the PR was closed or merged.

@Inkorak

Inkorak commented Sep 6, 2022

@patrickvonplaten Perhaps it is worth looking at this implementation, which adapts to a lack of VRAM by trading inference speed for memory, but runs just as fast when there is enough free VRAM.
My implementation for diffusers:

from diffusers.models.attention import CrossAttention
import math
import torch
from torch import einsum
from einops import rearrange
import types

def forward(self, x, context=None, mask=None):
    batch_size, sequence_length, dim = x.shape

    h = self.heads

    q = self.to_q(x)
    context = context if context is not None else x
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q = self.reshape_heads_to_batch_dim(q)
    k = self.reshape_heads_to_batch_dim(k)
    v = self.reshape_heads_to_batch_dim(v)

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

    # check how much VRAM is currently free on this device
    stats = torch.cuda.memory_stats(q.device)
    mem_total = torch.cuda.get_device_properties(0).total_memory
    mem_active = stats['active_bytes.all.current']
    mem_free = mem_total - mem_active

    # rough estimate of what the full attention matrix would need
    # (4 bytes per element plus a safety factor)
    mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5
    steps = 1

    if mem_required > mem_free:
        # split into the smallest power-of-two number of slices that should fit
        steps = 2**(math.ceil(math.log(mem_required / mem_free, 2)))

    # fall back to a single pass if the sequence length is not divisible by the step count
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)

def optimize_attention(model):
    # monkey-patch every CrossAttention module's forward with the memory-aware version above
    for module in model.modules():
        if isinstance(module, CrossAttention):
            module.forward = types.MethodType(forward, module)

from diffusers import StableDiffusionPipeline
from torch import autocast

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

optimize_attention(pipe.unet)

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).images[0]  

@patrickvonplaten patrickvonplaten changed the title from "[WIP] Efficient Attention" to "Efficient Attention" Sep 6, 2022
@patrickvonplaten
Contributor Author

patrickvonplaten commented Sep 6, 2022

Hey @Inkorak,

I think this is more or less exactly the implementation I used - just the checkpointing is something I think we haven't implemented yet.
Could you point out the difference between your implementation and what's currently in the PR? :-)

@Inkorak

Inkorak commented Sep 6, 2022

@patrickvonplaten As far as I understand, this implementation looks at how much VRAM is free and chunks accordingly. It only chunks as much as it needs to, so performance is preserved.

    stats = torch.cuda.memory_stats(q.device)
    mem_total = torch.cuda.get_device_properties(0).total_memory
    mem_active = stats['active_bytes.all.current']
    mem_free = mem_total - mem_active

    mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5
    steps = 1

    if mem_required > mem_free:
        steps = 2**(math.ceil(math.log(mem_required / mem_free, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
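
As a rough worked example of that logic (hypothetical numbers, not taken from this thread): with about 2 GiB of free VRAM and an attention matrix estimated at 2.5 GiB, the ratio is rounded up to the next power of two, so the query sequence is processed in two slices:

import math

mem_free = 2 * 1024**3                     # assume ~2 GiB of VRAM is currently free
mem_required = 16 * 4096 * 4096 * 4 * 2.5  # exactly 2.5 GiB by the estimate above

steps = 1
if mem_required > mem_free:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free, 2))

print(steps)  # 2 -> the query sequence dimension is split into two slices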

@patrickvonplaten
Contributor Author

I see! That's smart! I think for this codebase it's a bit too much "black magic" and requires too much complex code. We value readability and intuitive code design a lot, and I think here we have a bit too much logic that doesn't directly correspond to the attention mechanism itself.

In my experience, just setting the "chunk size" to half the number of attention heads works pretty well, and we only lose about 10% speed.

Comment on lines 39 to 41
def set_attention_chunk(self, chunk_size: Optional[Union[str, int]] = "auto"):
    # set chunk_size = `None` to disable `set_attention_chunk`
    if chunk_size == "auto":
Member

Not sure about having it as a setter function, but adding another argument to __call__ isn't ideal either...
My reasoning is that we'll have multiple levels of optimization (fp16, accelerate loader, chunked softmax), and it would be great to have one place to control them all.

But maybe @patil-suraj or @pcuenca have a better idea where to put it API-wise, I'm just a bit conflicted about having a function here.

Contributor Author

Hmm - I don't really see the problem here, to be honest. Note that this is more or less exactly the same as gradient_checkpointing_enable(), which is also a setter function (we trade speed for memory). fp16 is also a setter function. Also, I don't fully understand what you mean by "one place" for all?

Other options are:

  • putting it in the config: no-go as it's not at all related to the architecture of the model
  • forward argument to forward(): Don't like this
    • IMO parameters to the forward function should be limited to either tensors that are used or flags that change the type of output that is returned (e.g. return_dict, output_type). Optimization parameters should be setters IMO (like gradient checkpointing, half(), ....)

Contributor

@patil-suraj patil-suraj Sep 6, 2022

Agree with Patrick here. The other options aren't really ideal, so let's go with this.

Although, I'm not sure if it's best to keep this in the pipeline. What if we do

pipe.unet.set_attention_chunk

as this is specific to the model.

Contributor Author

I see the point, but we don't have a good general default for all UNets - it's a good default for Stable Diffusion, so I think we need a pipeline function here.

@sradc

sradc commented Sep 6, 2022

Awesome stuff!

Imho set_attention_chunk could be more intuitively named; "attention_chunk" reads like a noun (a thing being set), which isn't really what the method does.

Best I can think of right now is: pipe.use_chunked_attention().

@patrickvonplaten
Contributor Author

patrickvonplaten commented Sep 6, 2022

Naming is indeed important here and I'm also not super happy with chunked attention.
Mainly because "chunked attention" usually refers to self-attention being chunked across the time dimension. What is really happening here is that the attention is "chunked" across the batch dimension and not the sequence dimension.

Maybe slice_attention_enable(attention_slice_size: Optional[Union[str, int]])?
Also @anton-l @pcuenca @patil-suraj what do you think?

@anton-l
Member

anton-l commented Sep 6, 2022

enable_attention_chunking or enable_chunked_attention could work as well. The chunking part sounds ok to me, since we're still working with tensor dimensions (be it the batch or the time dimension).

@sradc

sradc commented Sep 6, 2022

Good points! Another option that came to mind: pipe.enable_memory_optimized_attention(attention_slice_size: Optional[Union[str, int]]). (I'll bow out of the comments now though, and leave it to you folks!)

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for adding this! My only main comment is: should we expose set_attention_chunk on the pipeline, given it's very much related to the UNet model? What do we think about doing pipe.unet.set_attention_chunk?


if chunk_size == "auto":
    # half the attention head size is usually a good trade-off between
    # speed and memory
    chunk_size = self.unet.config.attention_head_dim // 2
Contributor

I think this should be handled in the model.

Contributor Author

Don't agree - this is a good default for Stable Diffusion, not for general UNets.

@patil-suraj
Contributor

slice_attention_enable doesn't seem to be informative enough. enable_attention_chunking or enable_chunked_attention sounds good to me.

@patrickvonplaten
Contributor Author

slice_attention_enable doesn't seem to be informative enough. enable_attention_chunking or enable_chunked_attention sounds good to me.

I want to move away from chunked_attention, as people will think of long-range attention models here. I like sliced_attention better - going for this now.

@pcuenca
Member

pcuenca commented Sep 6, 2022

enable_attention_slicing sounds perfect to me!
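
For reference, the name the thread converged on shipped as enable_attention_slicing on the pipeline. A usage sketch, mirroring the example at the top of this PR (exact signatures may differ slightly between diffusers versions):

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True
)
pipe = pipe.to("cuda")

# "auto" picks the default slice size; pass an int for manual control,
# or call pipe.disable_attention_slicing() to turn it off again
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).images[0]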

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@patrickvonplaten patrickvonplaten merged commit 5c4ea00 into main Sep 6, 2022
@patrickvonplaten patrickvonplaten deleted the efficient_attentino branch September 6, 2022 16:06
natolambert pushed a commit that referenced this pull request Sep 7, 2022
* up

* add tests

* correct

* up

* finish

* better naming

* Update README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
1. This commit adds stable-diffusion as a part of shark web.
2. The V-diffusion model has been disabled for now as it's not
   working(will raise a different patch with fix).
3. Add standard output in the web ui.
4. Add instructions to launch the shark-web.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* up

* add tests

* correct

* up

* finish

* better naming

* Update README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>