
Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support)#35469

Open
stancld wants to merge 2 commits into huggingface:main from stancld:ds/feat/layoutlmv3-flash-attn

Conversation

@stancld (Contributor) commented Dec 31, 2024

What does this PR do?

Closes #35467.

Performance benchmark

Speed & memory consumption for token classification training of a LayoutLMv3-like model with multilingual support, various auxiliary tasks, and masked language modelling.

GPU: 1x A100 80 GB
Batch size: 16, Accumulated gradient batches: 8

| Impl. | Speed | Peak memory |
|-------|-----------|-------------|
| Eager | ~2.0 it/s | 66.7 GiB |
| SDPA | ~3.0 it/s | 47.2 GiB |

Overall, a ~50% speed-up and a ~30% reduction in peak memory are observed.
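The numbers above come from an actual training run; for readers who want to reproduce the comparison, here is a minimal inference-side sketch of the measurement pattern (dummy inputs; assumes a transformers build where this PR's `attn_implementation="sdpa"` is available):

```py
import time
import torch
from transformers import LayoutLMv3Model

device = "cuda"
model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa"
).to(device).eval()

# Dummy batch: token ids, normalized (0-1000) bounding boxes, and a page image.
batch_size, seq_len = 16, 512
inputs = {
    "input_ids": torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device),
    "bbox": torch.randint(0, 1000, (batch_size, seq_len, 4), device=device),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long, device=device),
    "pixel_values": torch.rand(batch_size, 3, 224, 224, dtype=torch.float16, device=device),
}

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    for _ in range(3):  # warm-up iterations
        model(**inputs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        model(**inputs)
    torch.cuda.synchronize()

print(f"{10 / (time.perf_counter() - start):.1f} it/s")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB")
```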


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @ArthurZucker

@stancld changed the title [WIP] Add SDPA support for LayoutLMv3 model → Add SDPA support for LayoutLMv3 model on Dec 31, 2024
@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch 5 times, most recently from c5de661 to 923cdea on January 2, 2025
@stancld (Contributor, Author) commented Jan 14, 2025

Just pinging @NielsRogge or @ArthurZucker .)

@NielsRogge requested a review from qubvel on January 17, 2025
@qubvel (Contributor) left a comment:

Hi @stancld! Thanks for working on this, looks great 🤗 It would be nice to add some benchmark results to the docs to show that SDPA works faster than the eager one.

Comment on lines 382 to 402

@qubvel (Contributor) commented Jan 17, 2025:

Thanks for adding a test, but we have a common test `test_eager_matches_sdpa_inference`, so no need for another one (it's enabled once `_supports_sdpa = True` is set). But it's great to see this one pass.
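For context, a condensed sketch of that opt-in (`_supports_sdpa` is the real transformers flag; the rest of the class body is abbreviated here):

```py
from transformers import LayoutLMv3Config, PreTrainedModel

class LayoutLMv3PreTrainedModel(PreTrainedModel):
    config_class = LayoutLMv3Config
    base_model_prefix = "layoutlmv3"

    # Opting in here is what enables the common
    # test_eager_matches_sdpa_inference test for the model.
    _supports_sdpa = True
```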

@stancld (Contributor, Author) commented Jan 17, 2025

@qubvel Thanks for the notes :] Will run some speed benchmarks with various seq lens & batch sizes tonight and add to the docs :]

@stevhliu (Member) left a comment:

Thanks!

Comment thread on docs/source/en/model_doc/layoutlmv3.md (outdated)
Comment on lines +46 to +50:
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

Suggested change
- PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
- encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
- [official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
- or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
- page for more information.
+ PyTorch includes a native scaled dot-product attention (SDPA) operator as part of [torch.nn.functional](https://pytorch.org/docs/stable/nn.functional.html). This function
+ encompasses several memory-efficient attention implementations that can be applied depending on the inputs and hardware. See the
+ [official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+ or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+ page for more information.

Comment thread on docs/source/en/model_doc/layoutlmv3.md (outdated)
Comment on lines +52 to +62:
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import LayoutLMv3Model

model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

Suggested change
- SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
- `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
- ```
- from transformers import LayoutLMv3Model
- model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa")
- ...
- ```
- For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+ SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+ `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA.
+ For the best speedups, we recommend loading the model in half-precision (`torch.float16` or `torch.bfloat16`).
+ ```py
+ from transformers import LayoutLMv3Model
+ model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa")
+ ```

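Since the PR title also advertises Flash Attention 2, usage would presumably mirror the standard transformers pattern (this snippet is illustrative, not from the PR docs; it assumes the separately installed `flash-attn` package and a supported NVIDIA GPU):

```py
import torch
from transformers import LayoutLMv3Model

# Flash Attention 2 only runs in fp16/bf16 on recent NVIDIA GPUs and
# requires the `flash-attn` package.
model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```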
@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from 7853314 to f324038 on January 22, 2025
@ArthurZucker (Collaborator) left a comment:

Very welcome!
Sorry for delaying the review; I saw that it was not clean (unrelated example files).
If you want to push this through, let's maybe use the latest API for the attention interface! (see modeling llama's attention layer!)
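For reference, the pattern being pointed at looks roughly like this (a sketch following the attention-interface names in recent transformers and modeling_llama; the `SketchSelfAttention` class is a hypothetical illustration, not this PR's code):

```py
from typing import Callable, Optional

import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain softmax attention, used when config._attn_implementation == "eager".
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:  # additive mask (0 for keep, large negative for masked)
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    # Return in (batch, seq_len, num_heads, head_dim) layout, matching the registered kernels.
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class SketchSelfAttention(nn.Module):
    """Hypothetical layer showing only the dispatch, not the PR's actual code."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = head_dim**-0.5

    def forward(self, query, key, value, attention_mask: Optional[torch.Tensor] = None):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            # Resolves to the "sdpa" or "flash_attention_2" kernel registered by transformers.
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask, scaling=self.scaling, dropout=0.0
        )
        return attn_output, attn_weights
```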

@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from e469f2d to a5828b2 on February 17, 2025
@stancld (Contributor, Author) commented Feb 17, 2025

@ArthurZucker It required a rebase; dunno why it looked like those unrelated files were touched.

Will check the new API.

@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from a5828b2 to d9c6c88 on February 17, 2025
@stancld changed the title Add SDPA support for LayoutLMv3 model → Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) on Feb 17, 2025
@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from d9c6c88 to 37a9581 on February 17, 2025
@stancld changed the title Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) → Draft: Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) on Feb 17, 2025
@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch 3 times, most recently from 060b80d to b01d58f on February 17, 2025
@stancld changed the title Draft: Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) → Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) on Feb 17, 2025
@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from b01d58f to 7b51a17 on February 17, 2025
@stancld changed the title Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) → Draft: Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) on Feb 17, 2025
@stancld (Contributor, Author) commented Feb 17, 2025

@ArthurZucker Flash Attn impl is still broken here. Will have a look as time allows .]

@stancld force-pushed the ds/feat/layoutlmv3-flash-attn branch from 7b51a17 to 172f327 on February 17, 2025
@stancld changed the title Draft: Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) → Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) on Feb 17, 2025
@ArthurZucker (Collaborator) left a comment:

thanks

Comment on lines +465 to +470:

```py
self.self = nn.ModuleDict(
    {
        "query": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
        "key": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
        "value": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
    }
)
```

@ArthurZucker (Collaborator): guessing we cannot remove this for BC! OK 🤗


```py
def _cogview_attention(attention_scores: torch.Tensor, alpha: Union[int, float] = 32) -> torch.Tensor:
```

@ArthurZucker (Collaborator): we don't really need a separate function, but if we keep it, place it at the top please
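For background, this helper carries over the PB-Relax trick from the CogView paper that LayoutLMv3's eager attention already used; a sketch of the behavior (mirroring the long-standing `cogview_attention` method, not necessarily the PR's exact code):

```py
import torch

def _cogview_attention(attention_scores: torch.Tensor, alpha: float = 32.0) -> torch.Tensor:
    # PB-Relax (CogView): scale scores down by alpha, subtract the per-row max,
    # then scale back up so intermediate values stay within fp16 range.
    scaled_attention_scores = attention_scores / alpha
    max_value = scaled_attention_scores.amax(dim=-1, keepdim=True)
    return (scaled_attention_scores - max_value) * alpha
```

Because this amounts to subtracting the row-wise maximum of the scores before the softmax, the softmax output is mathematically unchanged.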

@ArthurZucker (Collaborator) commented:

One nit: solve conflicts and good to go!


Development

Successfully merging this pull request may close these issues:

Support SDPA & Flash Attention 2 for LayoutLMv3 (#35467)

4 participants