Description
Recently, we have been receiving issues from users complaining that SDPA leads to OOMs whereas xformers doesn't.
So, to help identify the root cause, I put together a simple benchmark comparing the timings of the different efficient attention implementations provided by SDPA and xformers. I have restricted this to the forward pass for now.
My Colab Notebook is available here: https://colab.research.google.com/drive/1Ne0YPY16G2gmr9H1eCa9iB1wG5jLwf-h?usp=sharing
The setup is simple. Let's start by defining some variables:
```python
import torch

batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
device = "cuda"

query = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
key = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
value = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
```

I then followed this tutorial to perform the benchmark with `F.scaled_dot_product_attention()`, and extended it to use `xformers.ops.memory_efficient_attention()`. Refer to the Colab Notebook for more details.
Here's the gist of the results:
**SDPA**

```
[---------------------------------------------------]
                     |  flash  |  mem  |  math
6 threads: ------------------------------------------
      torch.float16  |   2.7   |  3.0  |  7.0

Times are in milliseconds (ms).
```

**xformers**

```
[--------------------------------------------]
                     |  cutlass  |  flash
6 threads: -----------------------------------
      torch.float16  |    2.6    |   1.7

Times are in milliseconds (ms).
```

We see that the flash variant obtains better performance here. This small difference can compound quickly when we're dealing with deep models. Additionally, I don't think we satisfy the required conditions in `UNet2DConditionModel` to be able to take advantage of flash attention. More details: #3594 (comment).
The above issue comment confirms that when using SDPA, the dispatcher selects the `SDPBackend.EFFICIENT_ATTENTION` variant, which corresponds to the 3.0 ms timing above.
The cutlass variant in xformers comes in at 2.6 ms, which is still faster than that.
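One way to double-check which backend actually runs is to restrict SDPA to a single implementation via the `torch.backends.cuda.sdp_kernel` context manager (PyTorch 2.0); this is just a sketch and reuses the `query`/`key`/`value` tensors defined above. If the inputs don't satisfy a backend's constraints, the call raises an error instead of silently falling back.

```python
import torch
import torch.nn.functional as F

# Restrict SDPA to the flash-attention kernel only. If the inputs don't meet
# flash attention's requirements, this raises a RuntimeError rather than
# dispatching to another backend.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out_flash = F.scaled_dot_product_attention(query, key, value)

# Restrict SDPA to the memory-efficient kernel to reproduce the timing
# the dispatcher gives us by default here.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out_mem = F.scaled_dot_product_attention(query, key, value)
```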
The Triton variant of the memory-efficient implementation errors out. See: facebookresearch/xformers#769.
Tagging a few people here to hear their thoughts:
@williamberman @takuma104 @patrickvonplaten @pcuenca
Let me know if anything is unclear.