[Attn Masks] Non-vmap default for attention masks #41852

vasqu merged 17 commits into huggingface:main
On `def sdpa_mask_without_vmap(`:
No longer needed as vmap was the reason we needed this workaround in the first place
```diff
     NOTE: It is important to keep an index-based version for non-vmap expansion.
     """
-    return q_idx.new_ones((), dtype=torch.bool)
+    return q_idx >= 0
```
As noted above, for non-vmap expansion we need the index-based version.
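A minimal sketch of why this matters (illustrative shapes, not the repository code): under vmap, the mask function is evaluated per scalar index pair, so a 0-dim result is fine; with broadcasting-based expansion, the function is called once on index grids, so the returned tensor must carry the index shape.

```python
import torch

# Hypothetical index grids as a broadcasting-based mask expansion would pass
# them; shapes are chosen so comparisons broadcast to the full (q_len, kv_len).
q_idx = torch.arange(4).view(-1, 1)   # shape (4, 1)
kv_idx = torch.arange(6).view(1, -1)  # shape (1, 6)

# Scalar variant: fine under vmap, but collapses the mask under broadcasting.
scalar = q_idx.new_ones((), dtype=torch.bool)  # shape ()

# Index-based variant: always True as well, but it keeps the index shape,
# so it broadcasts correctly against other index comparisons.
indexed = (q_idx >= 0) & (kv_idx >= 0)  # shape (4, 6)

print(scalar.shape, indexed.shape)  # torch.Size([]) torch.Size([4, 6])
```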
```diff
-        causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
-        return causal_mask
+        attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
```
I encountered issues with the in-place version, where we'd need a clone (e.g. when using SWA). This is safer.
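A small sketch of the pitfall (illustrative values, not the repository code): `|=` mutates the tensor in place, so any other reference to that mask, e.g. one reused by a sliding-window path, silently sees the change unless the mask was cloned first; the out-of-place `|` allocates a fresh tensor instead.

```python
import torch

# A mask where the first row is fully masked-out (all False).
mask = torch.tensor([[False, False], [True, False]])
shared = mask  # another reference, e.g. held by a sliding-window variant

# Out-of-place: "unmask" fully-masked rows into a *new* tensor.
fixed = mask | torch.all(~mask, dim=-1, keepdim=True)
assert fixed[0, 0].item() and not shared[0, 0].item()  # original untouched

# In-place: mutates `mask`, and every shared reference sees it.
mask |= torch.all(~mask, dim=-1, keepdim=True)
assert shared[0, 0].item()  # shared reference silently changed
```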
ArthurZucker left a comment:

Can we add a test to the default tests checking that there is no graph break on this?
Hi @vasqu. Is anything blocking the merge?
Cyrilvallez left a comment:
Very very nice, this solves quite a lot of different issues at the same time! I'm very happy to avoid special handling for export! Very clever use of broadcasting from the optimum team, I did not know we could simply do such things! Thanks a lot for upstreaming directly to us.
Do you mind expanding a bit more on what the limitations of the broadcasting approach are here, for posterity? Is it only index-based operations, as you mention in the comments, or are there more subtle things?
Nothing I'm aware of; the only condition is to write the `mask_function` as a comparison between the indexes (and constants).
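For posterity, a minimal sketch of the broadcasting idea (`build_mask` and `causal_mask_fn` are hypothetical names, not the library API): when the mask function is a pure comparison between index tensors, calling it once on broadcastable index grids yields the full mask with no vmap.

```python
import torch

def causal_mask_fn(batch_idx, head_idx, q_idx, kv_idx):
    # Pure index comparison -> broadcasts over the index grids.
    return kv_idx <= q_idx

def build_mask(mask_fn, batch, q_len, kv_len):
    # Index grids shaped to broadcast to (batch, 1, q_len, kv_len).
    batch_idx = torch.arange(batch).view(-1, 1, 1, 1)
    head_idx = torch.arange(1).view(1, -1, 1, 1)
    q_idx = torch.arange(q_len).view(1, 1, -1, 1)
    kv_idx = torch.arange(kv_len).view(1, 1, 1, -1)
    mask = mask_fn(batch_idx, head_idx, q_idx, kv_idx)
    # Expand in case the function did not touch every index dimension.
    return mask.expand(batch, 1, q_len, kv_len)

mask = build_mask(causal_mask_fn, batch=2, q_len=3, kv_len=4)
print(mask.shape)  # torch.Size([2, 1, 3, 4])
```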
Merging this then! Let's see what crazy masks come up in the future; for now the "mask hypothesis" holds 😆
@jiqing-feng was mainly blocked by me being out last week ;)
* atmpt 1
* fixup masking to work correctly with old torch
* few changes to make things a bit more cleaner
* oopsie
* fix integer overflow on bidirectional masks via indexing fn
* rm executorch workarounds --> still need to handle on sliding etc fns properly
* typo
* docs, fix older torch inplace issue, proper kwarg handling
* chunked works with non vmap and older torch, add warning on non guaranteed masks
* lift unnecessary restriction on older torch
* simplify a few things, restrict torch < 2.6 to non-vmap (for now)
* try fix
* remove unnecessary slicing logic
* remove legacy func
* harmonize slightly

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Non-vmap creation of masks. These work with all our base masks, and we only default back to vmap when using patterns we cannot guarantee (i.e. additional and/or masks).
Note:
Fixes #41639
cc @jiqing-feng @IlyasMoutawwakil