Use new attention API for LayoutLMv3 (SDPA, Flash Attn v2 support) #35469
stancld wants to merge 2 commits into huggingface:main
Conversation
Force-pushed c5de661 to 923cdea
Just pinging @NielsRogge or @ArthurZucker :)
Thanks for adding a test, but we have a common test `test_eager_matches_sdpa_inference`, so there's no need for another one (it's enabled once `_supports_sdpa = True` is set). But it's great to see this one pass.
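(For readers following along, a minimal sketch of what enabling the common test looks like; the class name mirrors the one in modeling_layoutlmv3.py, and the body is illustrative, not the PR's code.)

```py
from transformers import PreTrainedModel

# Sketch: the shared test_eager_matches_sdpa_inference is picked up once the
# model's PreTrainedModel subclass declares SDPA support via this class flag.
class LayoutLMv3PreTrainedModel(PreTrainedModel):
    _supports_sdpa = True
```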
Force-pushed 923cdea to 7853314
@qubvel Thanks for the notes :] Will run some speed benchmarks with various seq lens & batch sizes tonight and add them to the docs :]
Reviewed hunk:

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

Suggested change:

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of [torch.nn.functional](https://pytorch.org/docs/stable/nn.functional.html). This function
encompasses several memory-efficient attention implementations that can be applied depending on the inputs and hardware. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
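(To make the doc text concrete, a minimal, standalone call to the operator both versions link to; shapes are illustrative and not part of the PR.)

```py
import torch
import torch.nn.functional as F

# Minimal call to the operator described above; tensors have shape
# (batch, heads, seq_len, head_dim) and the values are arbitrary.
query = torch.randn(2, 12, 16, 64)
key = torch.randn(2, 12, 16, 64)
value = torch.randn(2, 12, 16, 64)
out = F.scaled_dot_product_attention(query, key, value)  # dispatches to a fused kernel when available
print(out.shape)  # torch.Size([2, 12, 16, 64])
```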
Reviewed hunk:

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import LayoutLMv3Model

model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

Suggested change:

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA.
For the best speedups, we recommend loading the model in half-precision (`torch.float16` or `torch.bfloat16`).

```py
from transformers import LayoutLMv3Model
model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", torch_dtype=torch.float16, attn_implementation="sdpa")
```
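(The PR title also advertises Flash Attention 2; presumably the analogous explicit request looks like the sketch below. It assumes the `flash-attn` package is installed and a supported CUDA GPU is available; the checkpoint id simply mirrors the SDPA snippet.)

```py
import torch
from transformers import LayoutLMv3Model

# Sketch: requesting the Flash Attention 2 backend added by this PR.
# FA2 requires half-precision and a CUDA device.
model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```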
Force-pushed 7853314 to f324038
ArthurZucker left a comment:
Very welcome! Sorry for delaying the review; I saw that it was not clean (unrelated example files).
If you want to push this through, let's maybe use the latest API for the attention interface (see modeling Llama's attention layer)!
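(For reference, a condensed sketch of the pattern being referred to, as it appears in recent transformers versions (roughly 4.48+); the registry lookup is the library's API, while the in-class excerpt is paraphrased from modeling_llama.py and may differ across versions.)

```py
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Registered attention backends; the exact set depends on the installed version.
print(list(ALL_ATTENTION_FUNCTIONS.keys()))  # e.g. ['flash_attention_2', 'flex_attention', 'sdpa']

# Condensed from LlamaAttention.forward: the layer dispatches through the
# registry instead of hard-coding eager/SDPA/flash branches.
#
#     attention_interface = eager_attention_forward
#     if self.config._attn_implementation != "eager":
#         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
#     attn_output, attn_weights = attention_interface(
#         self, query_states, key_states, value_states, attention_mask,
#         dropout=0.0 if not self.training else self.attention_dropout,
#         scaling=self.scaling,
#     )
```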
Force-pushed e469f2d to a5828b2
@ArthurZucker It required a rebase; not sure why it looked like those unrelated files were touched. Will check the new API.
Force-pushed a5828b2 to d9c6c88
Force-pushed d9c6c88 to 37a9581
Force-pushed 060b80d to b01d58f
Force-pushed b01d58f to 7b51a17
@ArthurZucker The Flash Attn implementation is still broken here. Will have a look as time allows :]
Force-pushed 7b51a17 to 172f327
Reviewed hunk:

```py
self.self = nn.ModuleDict(
    {
        "query": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
        "key": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
        "value": nn.Linear(config.hidden_size, config.num_attention_heads * self.attention_head_size),
    }
```
guessing we cannot remove this for BC! OK 🤗
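(A quick illustration of the BC point: the `ModuleDict` keys reproduce the original `attention.self.query`/`key`/`value` parameter paths, so existing checkpoints keep loading. The hidden size 768 below is just LayoutLMv3-base's default.)

```py
import torch.nn as nn

# The submodule names become state_dict key prefixes, matching what existing
# LayoutLMv3 checkpoints were serialized with (...attention.self.query.weight, etc.).
attn_self = nn.ModuleDict(
    {
        "query": nn.Linear(768, 768),
        "key": nn.Linear(768, 768),
        "value": nn.Linear(768, 768),
    }
)
print(sorted(attn_self.state_dict().keys()))
# ['key.bias', 'key.weight', 'query.bias', 'query.weight', 'value.bias', 'value.weight']
```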
Reviewed hunk:

```py
)

def _cogview_attention(attention_scores: torch.Tensor, alpha: Union[int, float] = 32) -> torch.Tensor:
```
We don't really need a separate function, but if we keep it, place it at the top please.
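(For context, a sketch of what the helper presumably computes: the CogView trick of scaling scores down by alpha and subtracting the row max before softmax, which keeps fp16 attention numerically stable. The body below is illustrative, modeled on the existing cogview_attention in modeling_layoutlmv3.py.)

```py
from typing import Union

import torch

# Illustrative reimplementation of the CogView attention-stabilization trick,
# mirroring the helper's signature in the hunk above.
def _cogview_attention(attention_scores: torch.Tensor, alpha: Union[int, float] = 32) -> torch.Tensor:
    scaled_attention_scores = attention_scores / alpha          # shrink scores to avoid fp16 overflow
    max_value = scaled_attention_scores.amax(dim=-1).unsqueeze(-1)
    new_attention_scores = (scaled_attention_scores - max_value) * alpha
    return new_attention_scores.softmax(dim=-1)

probs = _cogview_attention(torch.randn(2, 12, 16, 16))  # rows sum to 1
```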
One nit: solve conflicts and good to go!
What does this PR do?
Closes #35467.
Performance benchmark
Speed & memory consumption on token classification training of a LayoutLMv3-like model with multilingual support, various auxiliary tasks, and masked language modelling.
GPU: 1x A100 80 GB
Batch size: 16, Accumulated gradient batches: 8
Overall, a ~50% speed-up and memory-requirement reduction is observed.
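(Not from the PR: a hypothetical micro-benchmark showing how such numbers could be reproduced for inference; the model id, input shapes, and step count are illustrative, and the training setup above used different tooling.)

```py
import torch
from transformers import LayoutLMv3Model

# Time the text-only forward pass for a given attention implementation.
def time_forward(attn_impl: str, batch_size: int = 16, seq_len: int = 512, steps: int = 20) -> float:
    model = LayoutLMv3Model.from_pretrained(
        "microsoft/layoutlmv3-base",
        torch_dtype=torch.float16,
        attn_implementation=attn_impl,
    ).to("cuda").eval()
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device="cuda")
    bbox = torch.zeros(batch_size, seq_len, 4, dtype=torch.long, device="cuda")  # dummy layout boxes
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    with torch.no_grad():
        for _ in range(steps):
            model(input_ids=input_ids, bbox=bbox)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # milliseconds per forward pass

for impl in ("eager", "sdpa"):
    print(impl, f"{time_forward(impl):.1f} ms")
```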
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @ArthurZucker