[ESM] Add support for sdpa.#34954
Conversation
|
Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could compare that output logits are similar that'd give us a lot more confidence in the SDPA code! |
Thanks for your reply, I will add relevant test cases soon. |
8f7773d to
996880a
Compare
@Rocketknight1 Hello, the sdpa inference tests for ESMFold has been added. Could you please review it? |
|
Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hello @Rocketknight1, I found the test failures were due to the device mismatching of the I will report this to |
|
Hello @Rocketknight1 , I made a quick fix according to other model's test, the test cases should work normally now. |
|
Yes, looks good to me now! cc @ArthurZucker @LysandreJik for core maintainer review |
|
@ArthurZucker @LysandreJik Hello! Can you please help review this pr? |
ArthurZucker
left a comment
There was a problem hiding this comment.
Hey super sorry for the delay, waited a bit because #35235 changes the interface! Do you mind updating this PR ? Hope it's not too much of a burden! 😿
Sure, I will do it soon. |
|
Not sure if this is still active, but I have a similar PR in #38023 to add flash attention 2 to ESM |
|
The FA2 per was merged, TBH we'd rather have a small refactor to use the new |
f5d9ecc to
7400d4b
Compare
|
@ArthurZucker @Rocketknight1 Sorry for the late update. I have merged the sdpa support into the new codebase, can you help review this? |
ArthurZucker
left a comment
There was a problem hiding this comment.
thanks for updating!
| class EsmSdpaSelfAttention(EsmSelfAttention): | ||
| def __init__(self, config, position_embedding_type=None): | ||
| super().__init__(config, position_embedding_type) | ||
| self.attention_dropout_prob = config.attention_probs_dropout_prob | ||
| self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| attention_mask: Optional[torch.FloatTensor] = None, | ||
| head_mask: Optional[torch.FloatTensor] = None, | ||
| encoder_hidden_states: Optional[torch.FloatTensor] = None, | ||
| encoder_attention_mask: Optional[torch.FloatTensor] = None, | ||
| past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||
| output_attentions: Optional[bool] = False, | ||
| ) -> Tuple[torch.Tensor]: | ||
| if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None: | ||
| # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. | ||
| logger.warning_once( | ||
| "EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " | ||
| "non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. " | ||
| "Falling back to the manual attention implementation, but specifying the manual implementation will " | ||
| "be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument " | ||
| '`attn_implementation="eager"` when loading the model.' | ||
| ) | ||
| return super().forward( | ||
| hidden_states, | ||
| attention_mask, | ||
| head_mask, | ||
| encoder_hidden_states, | ||
| encoder_attention_mask, | ||
| past_key_value, | ||
| output_attentions, | ||
| ) | ||
|
|
||
| bsz, tgt_len, _ = hidden_states.size() | ||
|
|
||
| query_layer = self.transpose_for_scores(self.query(hidden_states)) | ||
|
|
||
| # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention | ||
| # mask needs to be such that the encoder's padding tokens are not attended to. | ||
| is_cross_attention = encoder_hidden_states is not None | ||
|
|
||
| current_states = encoder_hidden_states if is_cross_attention else hidden_states | ||
| attention_mask = encoder_attention_mask if is_cross_attention else attention_mask | ||
|
|
||
| # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning | ||
| if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: | ||
| key_layer, value_layer = past_key_value | ||
| else: | ||
| key_layer = self.transpose_for_scores(self.key(current_states)) | ||
| value_layer = self.transpose_for_scores(self.value(current_states)) | ||
| if past_key_value is not None and not is_cross_attention: | ||
| key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | ||
| value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | ||
|
|
||
| # Scale the query for rotary embeddings | ||
| query_layer = query_layer * self.attention_head_size**-0.5 | ||
|
|
||
| if self.is_decoder: | ||
| # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | ||
| # Further calls to cross_attention layer can then reuse all cross-attention | ||
| # key/value_states (first "if" case) |
There was a problem hiding this comment.
I think you are still missing the spot : we don't need 3 different classes anymore https://github.com/huggingface/transformers/blob/tp-cb/src/transformers/models/llama/modeling_llama.py#L249-L249
|
looks like #38751 broke the existing flash_attention_2 implementation for ESM-2 as well, so we're back to only eager being supported. |
|
cc @zucchini-nlp ! |
What does this PR do?
Add support for SDPA (scaled dot product attention) for ESM. More context in #28802 (And this pr mainly reused the code from this pr as the ESM is Bert-based model) and #28005 .
This is my first time contributing to this project, please point out if there is any mistakes.
And revert a change in #29329 as the dtype-mismatching issue for bitsandbytes is actually caused by the rotary embedding.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker