[OWL-VIT] Added sdpa attention #28818 (Closed)

nileshkokane01 wants to merge 10 commits into huggingface:main from nileshkokane01:sdpa_for_OWL_ViT

Conversation

@nileshkokane01 (Contributor) commented Feb 1, 2024

What does this PR do?

This PR adds SDPA (scaled dot-product attention) support for OWL-ViT.

Fixes #28103
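For context, once this lands, users should be able to opt in through the standard attn_implementation flag at load time. A hedged sketch; the checkpoint name is just an example:

```python
from transformers import OwlViTModel

# attn_implementation="sdpa" routes attention through
# torch.nn.functional.scaled_dot_product_attention.
model = OwlViTModel.from_pretrained(
    "google/owlvit-base-patch32",
    attn_implementation="sdpa",
)
```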

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@NielsRogge @younesbelkada
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@nileshkokane01 (Contributor, Author) commented:

I have added an initial draft for SDPA attention, but I suspect more changes are needed, since OWL-ViT is a bit different from Llama or Mistral.

Could you point me to a similar model that is close to OWL-ViT, or let me know the additional changes required in the file? Also, causal_attention_mask is not handled correctly, and I have no clue how to handle it (a rough sketch of the overall structure follows). A corresponding test case is also necessary.
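For orientation, a minimal sketch of the shape such an SDPA attention class usually takes in transformers. This is not the code from this PR; the class name, the eager fallback, and the attribute names (q_proj, num_heads, head_dim, dropout) are assumptions modeled on the existing OwlViTAttention:

```python
import torch.nn.functional as F


class OwlViTSdpaAttention(OwlViTAttention):
    """SDPA variant that reuses the eager module's projection layers."""

    def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None, output_attentions=False):
        if output_attentions:
            # SDPA cannot return attention weights, so fall back to the eager path.
            return super().forward(
                hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=True,
            )

        bsz, tgt_len, embed_dim = hidden_states.size()

        # Keep the 4D layout (bsz, num_heads, seq_len, head_dim) that SDPA expects.
        query = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        # How to fold causal_attention_mask in is the open question discussed below.
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        return self.out_proj(attn_output), None
```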

@ArthurZucker (Collaborator) commented:

fyi @younesbelkada

@younesbelkada (Contributor) left a review comment:

Good work, thanks! I left a few nits and a suggestion to fix the failing CI.
Could you follow the same logic as here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1032-L1045 for preparing the attention mask for SDPA?
Also, similarly to Llama, could you add _supports_sdpa = True in OwlViTPreTrainedModel? That way the SDPA tests would be triggered (a sketch follows).
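For illustration, the flag change could look like the snippet below. Only the _supports_sdpa line is the suggested addition; the other attributes are assumptions copied from what the existing OwlViTPreTrainedModel typically declares:

```python
from transformers import PreTrainedModel
from transformers.models.owlvit.configuration_owlvit import OwlViTConfig


class OwlViTPreTrainedModel(PreTrainedModel):
    config_class = OwlViTConfig
    base_model_prefix = "owlvit"
    supports_gradient_checkpointing = True

    # Opting in makes the common test suite run the SDPA/eager equivalence checks.
    _supports_sdpa = True
```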

(Two review-comment threads on src/transformers/models/owlvit/modeling_owlvit.py, now marked outdated.)
@younesbelkada (Contributor) left a review comment:

Hi @nileshkokane01,
Thanks a lot for your hard work! It looks much cleaner! We're almost there; it seems some tests are failing:

FAILED tests/models/owlvit/test_modeling_owlvit.py::OwlViTTextModelTest::test_model_outputs_equivalence - TypeError: _prepare_4d_causal_attention_mask_for_sdpa() missing 1 required positional argument: 'past_key_values_length'

Are you able to reproduce these failures locally? I think the fix is to hardcode past_key_values_length to 0 in the call to that method, since OWL-ViT is not a generative text model and therefore does not use a caching mechanism (a sketch of the call follows).
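For illustration, a sketch of that call, assuming the helper from transformers.modeling_attn_mask_utils and locally available batch_size, seq_len, and inputs_embeds:

```python
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

# OWL-ViT never generates text, so there is no KV cache and the offset is always 0.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_len),  # input_shape
    inputs_embeds,
    past_key_values_length=0,
)
```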

@nileshkokane01 (Contributor, Author) commented:

@younesbelkada,
I'm trying to solve the errors. I'll let you know when it's all ready or if I need any assistance.

@nileshkokane01 (Contributor, Author) commented:

@younesbelkada,

I get the following error: the batch dimension gets dropped, so the shapes no longer match. Any clues?

[192, 16, 16] doesn't match the broadcast shape [48, 192, 16, 16]

Also, causal_attention_mask is not used at all in SDPA; I don't know how to handle it on the line below (one possible approach is sketched after the link).

https://github.com/nileshkokane01/transformers/blob/sdpa_for_OWL_ViT/src/transformers/models/owlvit/modeling_owlvit.py#L484
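One common way to handle this (an assumption, not something confirmed in this thread): both masks are additive float masks, 0.0 where attention is allowed and a large negative value where it is blocked, so they can simply be summed into the single attn_mask argument that SDPA accepts (F is torch.nn.functional, as in the earlier sketch):

```python
# Merge the padding mask and the causal mask; both are additive float masks.
if attention_mask is not None and causal_attention_mask is not None:
    attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
    attn_mask = causal_attention_mask
else:
    attn_mask = attention_mask  # may be None, which SDPA accepts

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
```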

@nileshkokane01 (Contributor, Author) commented:

@younesbelkada,

I sort of tried to fix the batch-size dimensionality mismatch but couldn't figure it out. Any clue?

RuntimeError: output with shape [192, 16, 16] doesn't match the broadcast shape [48, 192, 16, 16]

With this, 11 tests seem to fail. (A guess at the cause is sketched below.)
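A guess at the cause, for illustration only: the shapes in the traceback can be reproduced in isolation with dummy tensors, which shows why the in-place add fails:

```python
import torch

# (bsz * num_heads, seq, seq): the flattened 3D layout used by the eager attention path.
attn_weights = torch.zeros(192, 16, 16)
# The 4D mask the code tries to add in place.
mask = torch.zeros(48, 192, 16, 16)

# An in-place op cannot grow its output tensor, so broadcasting 3D += 4D raises:
# RuntimeError: output with shape [192, 16, 16] doesn't match the broadcast shape [48, 192, 16, 16]
attn_weights += mask
```

If that reading is right, keeping query/key/value (and hence the attention weights) 4D in the SDPA branch, as in the earlier sketch, would avoid the flattened layout entirely.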

@younesbelkada (Contributor) commented:

Hi @nileshkokane01,
Thanks for getting back! For that I need to dive deep into your branch and try to fix things; I will do that in the next few days 🙏

amyeroberts changed the title from "Added sdpa attention" to "[OWL-VIT] Added sdpa attention" on Feb 19, 2024.
huggingface deleted a comment from the github-actions bot on Mar 15, 2024.
github-actions (Bot) commented Apr 9, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.



Development

Successfully merging this pull request may close these issues.

OWL-VIT Vision Foundation Model deployment in the edge cases - Need SDPA support for OWL-ViT Model optimization for Edge Deployment
