SDPA for T5 Attention #31167
Open
huseinzol05 wants to merge 5 commits into huggingface:main
Conversation
ArthurZucker reviewed Jun 6, 2024
ArthurZucker (Collaborator) left a comment:
Hey! Could you make sure the CIs go green!? 🤗
huseinzol05 (Contributor, Author): Hi! I'm sorry, what is CIs?
ArthurZucker (Collaborator): Hey! It is the integration tests right below this message that are all red!
huseinzol05 (Contributor, Author): Passed, except for the quality check.
ArthurZucker reviewed Oct 3, 2024
ArthurZucker (Collaborator) left a comment:
Thanks for working on this! Let's try to re-use what we have in other modeling code to have consistent standards 🤗
Comment on lines +625 to +631:
```python
def shape(states):
    """projection"""
    return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(states):
    """reshape"""
    return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
```
ArthurZucker (Collaborator): No, let's remove these one-liners.
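For context, removing the one-liners means inlining the `view`/`transpose` calls at their use sites, as newer modeling files in the library do. Here is a self-contained sketch of that inlining pattern (hypothetical helper and shapes, not the PR's final code):

```python
import torch

def inlined_reshape_example(hidden_states: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Standalone demo of inlining the former `shape`/`unshape` one-liners (sketch, not PR code)."""
    batch_size, seq_length, inner_dim = hidden_states.shape
    key_value_proj_dim = inner_dim // n_heads
    # Inline the former `shape` one-liner at its use site:
    states = hidden_states.view(batch_size, seq_length, n_heads, key_value_proj_dim).transpose(1, 2)
    # ... attention would happen here, on (batch_size, n_heads, seq_length, dim_per_head) ...
    # Inline the former `unshape` one-liner on the way out:
    return states.transpose(1, 2).contiguous().view(batch_size, seq_length, inner_dim)

out = inlined_reshape_example(torch.randn(2, 7, 512), n_heads=8)
assert out.shape == (2, 7, 512)
```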
Comment on lines +633 to +658:
```python
def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))
    elif past_key_value is None:
        # cross-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(key_value_states))

    if past_key_value is not None:
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, key_length, dim_per_head)
            hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        elif past_key_value.shape[2] != key_value_states.shape[1]:
            # checking that the `sequence_length` of the `past_key_value` is the same as
            # the provided `key_value_states` to support prefix tuning
            # cross-attn
            # (batch_size, n_heads, seq_length, dim_per_head)
            hidden_states = shape(proj_layer(key_value_states))
        else:
            # cross-attn
            hidden_states = past_key_value
    return hidden_states
```
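To make the caching branch above concrete, here is a toy, self-contained illustration of the self-attention path (hypothetical sizes, not from the PR): the freshly projected key is appended to the cached keys along the sequence axis (`dim=2`).

```python
import torch

# Toy demo of the self-attention cache branch: new key states are concatenated
# onto the cached ones along the sequence axis (dim=2).
batch_size, n_heads, dim_per_head = 2, 8, 64
past_key = torch.randn(batch_size, n_heads, 10, dim_per_head)  # 10 cached positions
new_key = torch.randn(batch_size, n_heads, 1, dim_per_head)    # 1 freshly projected position
key_states = torch.cat([past_key, new_key], dim=2)
assert key_states.shape == (batch_size, n_heads, 11, dim_per_head)
```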
Comment on lines +660 to +688:
```python
# get query states
query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length),
            device=query_states.device,
            dtype=query_states.dtype,
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)

    # if key and values are already calculated
    # we want only the last query position bias
    if past_key_value is not None:
        position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

    if mask is not None:
        position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
```
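The hunk above stops just short of the actual attention call. For orientation, here is a self-contained sketch (not the PR's literal diff) of how these tensors would feed into PyTorch's fused kernel. Two T5-specific details matter: T5 does not scale attention scores by 1/sqrt(d), so the kernel's default scaling must be disabled via `scale=1.0` (available in torch >= 2.1), and the relative-position bias (plus any padding/causal mask) goes in as the additive `attn_mask`.

```python
import torch
import torch.nn.functional as F

# Stand-in tensors with the shapes used in the hunk above (hypothetical sizes):
batch_size, n_heads, seq_length, dim_per_head = 2, 8, 7, 64
query_states = torch.randn(batch_size, n_heads, seq_length, dim_per_head)
key_states = torch.randn(batch_size, n_heads, seq_length, dim_per_head)
value_states = torch.randn(batch_size, n_heads, seq_length, dim_per_head)
position_bias = torch.randn(1, n_heads, seq_length, seq_length)

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=position_bias,  # additive bias, broadcast over the batch dimension
    dropout_p=0.0,
    scale=1.0,  # T5 attention is unscaled, so disable the default 1/sqrt(d) factor
)
assert attn_output.shape == (batch_size, n_heads, seq_length, dim_per_head)
```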
ArthurZucker (Collaborator): The closer we are to UMT5 or WhisperSDPAAttention, the better! 🤗
huseinzol05 (Contributor, Author): Noted, I will try to patch it.
@huseinzol05 @ArthurZucker how are we feeling about these changes now?
What does this PR do?
Adds SDPA (PyTorch's torch.nn.functional.scaled_dot_product_attention) support for T5 Attention.
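Once merged, users would opt in the same way other SDPA-enabled architectures in transformers already do, via the `attn_implementation` switch. A hedged usage sketch (the checkpoint name is just an example):

```python
from transformers import T5ForConditionalGeneration

# Select the SDPA attention implementation at load time, as with other
# SDPA-enabled models in the library:
model = T5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small",
    attn_implementation="sdpa",
)
```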