
[Mixtral & Mistral] Add support for sdpa#28133

Merged
ArthurZucker merged 17 commits into main from add-mixtral-sdpa
Dec 21, 2023

Conversation

Collaborator

@ArthurZucker ArthurZucker commented Dec 19, 2023

What does this PR do?

Adds SDPA attention for both model classes. cc @younesbelkada for visibility 😉 This will help with fast LLaVA.
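For context, "SDPA" here means routing the attention computation through PyTorch's torch.nn.functional.scaled_dot_product_attention instead of the eager matmul + softmax path. A minimal sketch of what such a call looks like (an illustration assuming PyTorch >= 2.0, not the actual MistralSdpaAttention/MixtralSdpaAttention implementation):

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # attention_mask: optional additive float mask (0 / -inf) broadcastable to
    # (batch, num_heads, q_len, kv_len). When no mask is given, SDPA can build
    # the causal mask itself via is_causal=True.
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=attention_mask is None,
    )
```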

@ArthurZucker ArthurZucker marked this pull request as ready for review December 20, 2023 18:05
Contributor

@younesbelkada younesbelkada left a comment


Thanks!
I don't see why sliding window attention shouldn't be supported with SDPA, since the only difference vs. the eager attention implementation is the attention mask. Passing arbitrary attention masks to SDPA should be supported without any problem, IMO.
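For illustration, a small sketch of the equivalence being described, with hypothetical helper names and assuming PyTorch >= 2.0: the eager path and SDPA consume the same additive mask, so any mask the eager implementation accepts can also be handed to SDPA.

```python
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask):
    # Eager path: explicit matmul + softmax with an additive mask (0 / -inf).
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    scores = scores + mask
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def sdpa_attention_with_mask(q, k, v, mask):
    # SDPA path: the same additive mask is passed straight through as attn_mask.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```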

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

ehuaa commented Feb 12, 2024

> Thanks! I don't see why sliding window attention shouldn't be supported with SDPA, since the only difference vs. the eager attention implementation is the attention mask. Passing arbitrary attention masks to SDPA should be supported without any problem, IMO.

I have the same question here: why doesn't SDPA support sliding window attention? Are there any unresolved problems? @ArthurZucker

@younesbelkada
Contributor

@ehuaa The way sliding window attention is implemented in the original Mistral code base is by changing the attention mask to a more custom mask that does not attend to tokens older than sliding_window. Check out the details of this method:


The point I tried to convey is that passing that attention mask is, I think, supported in SDPA, so you can implicitly get SDPA + sliding window attention just by passing the correct attention mask. Let me know if this makes sense to you!
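To make that concrete, here is a hypothetical sketch of building such a sliding-window causal mask and passing it to SDPA. This is an illustration only, not the mask-preparation code in transformers:

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len, window, dtype=torch.float32):
    # Each query position may attend to itself and at most `window - 1` previous
    # positions; future positions and positions older than the window get a very
    # large negative value (effectively -inf).
    q_idx = torch.arange(seq_len).unsqueeze(-1)  # (seq_len, 1)
    k_idx = torch.arange(seq_len).unsqueeze(0)   # (1, seq_len)
    allowed = (k_idx <= q_idx) & (k_idx > q_idx - window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)

# Example: q, k, v of shape (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 16, 64)
mask = sliding_window_causal_mask(seq_len=16, window=4)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```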

Contributor

ehuaa commented Feb 13, 2024

> @ehuaa The way sliding window attention is implemented in the original Mistral code base is by changing the attention mask to a more custom mask that does not attend to tokens older than sliding_window. Check out the details of this method:
>
> The point I tried to convey is that passing that attention mask is, I think, supported in SDPA, so you can implicitly get SDPA + sliding window attention just by passing the correct attention mask. Let me know if this makes sense to you!

@younesbelkada Thank you for your quick reply! Your solution above passes a custom mask to SDPA, and I think this is equivalent to passing the sliding_window param to this function:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L1006-L1023
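A hypothetical usage sketch of that helper, the _prepare_4d_causal_attention_mask utility from transformers.modeling_attn_mask_utils as it existed around the time of this PR; the exact signature may have changed since, so treat the call below as an assumption:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, hidden_size = 1, 8, 32
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # no padding

# sliding_window folds the window restriction into the 4D causal mask that the
# model then hands to its attention layers (eager or SDPA alike).
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    (batch_size, seq_len),
    inputs_embeds,
    0,                    # past_key_values_length
    sliding_window=4,
)
print(mask_4d.shape)  # torch.Size([1, 1, 8, 8])
```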
