Better SDPA unmasking implementation #29318
Conversation
```python
if expanded_mask.dtype == torch.bool:
    raise ValueError("AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor.")
```
Some models (GPT-BigCode) use bool tensors, but Arthur's implementation can't work for that dtype.
Can't we cast it and replace the min with 0?
For now I expect the cast to be done explicitly in the modeling file.
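A minimal sketch of what such an explicit cast in a modeling file could look like. The helper name `bool_mask_to_float` and the convention that `True` means "attend" are assumptions for illustration, not the actual transformers code; the float form fills masked positions with the dtype minimum, which is what `_unmask_unattended` expects.

```python
import torch

def bool_mask_to_float(expanded_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: cast a bool attention mask to the float convention.

    Assumes True = "attend", False = "masked" (convention varies per model).
    Attended positions become 0.0; masked positions become the dtype minimum.
    """
    min_value = torch.finfo(dtype).min
    return torch.zeros_like(expanded_mask, dtype=dtype).masked_fill(~expanded_mask, min_value)

# Example: a [1, 1, 2, 2] causal-style bool mask.
bool_mask = torch.tensor([[[[True, False], [True, True]]]])
float_mask = bool_mask_to_float(bool_mask)
```

Doing the cast in the modeling file keeps the dtype contract of `_unmask_unattended` simple, at the cost of one extra line per model that uses bool masks.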
ArthurZucker
left a comment
LGTM, thanks for propagating the changes.
As @ArthurZucker improved the unmasking for the SDPA mem-efficient code path, let's do the same for all archs using SDPA: #27931