
[SDPA vs. xformers] Discussions on benchmarking SDPA and xformers and implications #3793

@sayakpaul


Recently, we have been receiving issues from users reporting that SDPA leads to OOMs whereas xformers does not.

So, to help identify the root cause, I put together a simple benchmark comparing the timings of the different efficient attention implementations provided by SDPA and xformers. I have restricted this to the forward pass for now.

My Colab Notebook is available here: https://colab.research.google.com/drive/1Ne0YPY16G2gmr9H1eCa9iB1wG5jLwf-h?usp=sharing

The setup is simple. Let's start by defining some variables:

import torch

batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16
device = "cuda"

query = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
key = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
value = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)

I then followed this tutorial to benchmark F.scaled_dot_product_attention(), and extended it to cover xformers.ops.memory_efficient_attention(). Refer to the Colab Notebook for more details.
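
For anyone who wants to reproduce this without the notebook, here is a minimal sketch of the timing setup. It assumes PyTorch 2.0's torch.backends.cuda.sdp_kernel context manager and xformers' op argument for kernel selection; the time_fn helper and the iteration count are mine, not from the tutorial.

import torch.nn.functional as F
import torch.utils.benchmark as benchmark
import xformers.ops as xops
from torch.backends.cuda import sdp_kernel

# Helper (my own): mean runtime of a callable in milliseconds.
# torch.utils.benchmark.Timer handles CUDA synchronization for us.
def time_fn(fn, *args, **kwargs):
    timer = benchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
    )
    return timer.timeit(100).mean * 1e3

# SDPA: time one backend at a time by disabling the others.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    print("flash:", time_fn(F.scaled_dot_product_attention, query, key, value))
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    print("mem_efficient:", time_fn(F.scaled_dot_product_attention, query, key, value))
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    print("math:", time_fn(F.scaled_dot_product_attention, query, key, value))

# xformers expects (batch, seq_len, num_heads, head_dim), so transpose the
# (batch, num_heads, seq_len, head_dim) tensors defined above.
q, k, v = (t.transpose(1, 2).contiguous() for t in (query, key, value))
print("cutlass:", time_fn(xops.memory_efficient_attention, q, k, v,
                          op=xops.MemoryEfficientAttentionCutlassOp))
print("flash:", time_fn(xops.memory_efficient_attention, q, k, v,
                        op=xops.MemoryEfficientAttentionFlashAttentionOp))

Note the transposes: SDPA and xformers use different input layouts, so benchmarking them on identical tensors without matching layouts would compare different memory access patterns.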

Here's the gist of the results:

SDPA

                     |  flash  |  mem_efficient  |  math
      torch.float16  |   2.7   |       3.0       |   7.0

Times are in milliseconds (ms).

xformers

                     |  cutlass  |  flash
      torch.float16  |    2.6    |    1.7

Times are in milliseconds (ms).

We see that the flash variant performs best here. A small per-call difference like this can compound quickly in deep models. Additionally, I don't think we satisfy the conditions required in the UNet2DConditionModel to take advantage of flash attention. More details: #3594 (comment).

The above issue comment confirms that when using SDPA, the dispatcher selects the SDPBackend.EFFICIENT_ATTENTION variant, which corresponds to the 3.0 ms figure above.
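
As a sanity check, one can restrict SDPA to the flash backend only and see whether dispatch fails for a given set of inputs; with every other backend disabled, SDPA raises a RuntimeError instead of silently falling back. A sketch, again assuming the PyTorch 2.0 sdp_kernel context manager:

import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

try:
    # Only flash is enabled; if the inputs don't satisfy its requirements,
    # there is no fallback and the call errors out.
    with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        F.scaled_dot_product_attention(query, key, value)
    print("flash attention kernel was used")
except RuntimeError as err:
    print(f"flash attention not available for these inputs: {err}")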

xformers' cutlass kernel comes in at 2.6 ms, which is still faster than that.

The Triton variant of the memory-efficient implementation errors out. See: facebookresearch/xformers#769.

Tagging a few people here to hear their thoughts:

@williamberman @takuma104 @patrickvonplaten @pcuenca

Let me know if anything is unclear.
