[OWL-VIT] Added sdpa attention #28818
nileshkokane01 wants to merge 10 commits into huggingface:main
Conversation
I have added an initial draft for SDPA attention, but I think more changes are needed, as OWL-ViT is a bit different from Llama or Mistral. Could you point me to a similar model close to OWL-ViT, or let me know what additional changes are required in the file? Also, causal_attention_mask is not handled correctly, and I have no clue how to handle it. A corresponding test case is also necessary.
fyi @younesbelkada
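For context, here is a minimal sketch of what an SDPA variant of the attention class could look like. The projection names (q_proj, k_proj, v_proj, out_proj) and attributes follow modeling_owlvit.py, but the class itself and the mask merging are illustrative assumptions, not the merged implementation:

```python
import torch
import torch.nn.functional as F

from transformers.models.owlvit.modeling_owlvit import OwlViTAttention


class OwlViTSdpaAttention(OwlViTAttention):
    """Sketch: OwlViTAttention using torch.nn.functional.scaled_dot_product_attention."""

    def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None, output_attentions=False):
        if output_attentions:
            # SDPA cannot return attention weights; fall back to the eager implementation.
            return super().forward(hidden_states, attention_mask, causal_attention_mask, output_attentions)

        bsz, tgt_len, embed_dim = hidden_states.size()
        query = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Fold the padding mask and the causal mask into one additive 4D mask,
        # since SDPA takes a single attn_mask argument.
        attn_mask = attention_mask
        if causal_attention_mask is not None:
            attn_mask = causal_attention_mask if attn_mask is None else attn_mask + causal_attention_mask

        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        return self.out_proj(attn_output), None
```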
younesbelkada left a comment
Good work, thanks! I left a few nits and a suggestion to fix the failing CI.
Could you follow the same logic as here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1032-L1045 for preparing the attention mask for SDPA?
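The referenced Llama logic looks roughly like this (simplified sketch; batch_size, seq_length, inputs_embeds, past_key_values_length and output_attentions are assumed to come from the surrounding forward):

```python
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

if self._use_sdpa and not output_attentions:
    # SDPA wants either a 4D additive mask or None (so its fast paths can kick in).
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
else:
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
```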
Also, similarly to Llama, could you add _supports_sdpa = True in OwlViTPreTrainedModel? That way the SDPA tests would be triggered (a sketch follows below).
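Concretely, the flag would look something like this; the other class attributes shown are the ones already on OwlViTPreTrainedModel:

```python
from transformers import OwlViTConfig, PreTrainedModel


class OwlViTPreTrainedModel(PreTrainedModel):
    config_class = OwlViTConfig
    base_model_prefix = "owlvit"
    supports_gradient_checkpointing = True
    _supports_sdpa = True  # opts the model into the common SDPA tests and dispatch
```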
nileshkokane01 force-pushed from 06bdf84 to 4fb902e
younesbelkada left a comment
Hi @nileshkokane01
Thanks a lot for your hard work! It looks much cleaner! We're almost there; it seems some tests are failing:
FAILED tests/models/owlvit/test_modeling_owlvit.py::OwlViTTextModelTest::test_model_outputs_equivalence - TypeError: _prepare_4d_causal_attention_mask_for_sdpa() missing 1 required positional argument: 'past_key_values_length'
Are you able to reproduce these failures locally? I think the fix should be to hardcode past_key_values_length to 0 in the call to that method, since OwlViT is not a generative text model and hence does not use a caching mechanism. A sketch of the fix follows.
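A sketch of that fix, assuming the call site is in the text transformer's forward (the surrounding variable names are illustrative):

```python
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

# OWL-ViT never generates, so there is no KV cache and the offset is always 0.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_length),
    hidden_states,
    past_key_values_length=0,
)
```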
@younesbelkada,
I get the following error because the batch size is dropped, so the dimensions do not match. Any clues?
Also, causal_attention_mask is not used at all in the SDPA path; I don't know how to handle it on the line below.
I sort of tried to fix the batch-size dimensionality mismatch, but couldn't figure it out. Any clue? With this, 11 tests seem to fail:
RuntimeError: output with shape [192, 16, 16] doesn't match the broadcast shape [48, 192, 16, 16]
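For what it's worth, the error reproduces with shapes like the following (the values are hypothetical, chosen only to match the log, where 48 * 4 = 192): a 3D (batch*heads, seq, seq) weight tensor added in place to a 4D mask cannot broadcast back into its own shape:

```python
import torch

bsz, num_heads, seq_len = 48, 4, 16
attn_weights = torch.randn(bsz * num_heads, seq_len, seq_len)  # eager path merges batch and heads
mask = torch.zeros(bsz, 1, seq_len, seq_len)                   # SDPA-style 4D additive mask

attn_weights += mask
# RuntimeError: output with shape [192, 16, 16] doesn't match the
# broadcast shape [48, 192, 16, 16]
```

This suggests the 4D SDPA mask needs to be reshaped, or the batch and head dimensions kept separate, before the additive step.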
Hi @nileshkokane01
nileshkokane01 force-pushed from 61fcbce to 3c056eb
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
This PR adds SDPA attention for OWL-ViT.
Fixes #28103
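Once merged, SDPA could then be requested at load time, for example (the checkpoint name is a real OWL-ViT checkpoint; the dtype choice is just an example):

```python
import torch
from transformers import OwlViTForObjectDetection

model = OwlViTForObjectDetection.from_pretrained(
    "google/owlvit-base-patch32",
    attn_implementation="sdpa",  # requires the model to set _supports_sdpa = True
    torch_dtype=torch.float16,
)
```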
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@NielsRogge @younesbelkada
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.