[ESM] support attention API #40370
Conversation
run-slow: esm

This comment contains run-slow, running the specified jobs: models: ['models/esm']
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: esm, evolla

This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']
```diff
 if self.config._attn_implementation != "flash_attention_2":
     batch_size, seq_length = inputs_embeds.shape[:-1]
     if attention_mask is None:
         attention_mask = torch.ones((batch_size, seq_length), device=inputs_embeds.device)
     else:
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape=(batch_size, seq_length)
+        )

 # If a 2D or 3D attention mask is provided for the cross-attention
 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
 if self.config.is_decoder and encoder_hidden_states is not None:
     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
     if encoder_attention_mask is None:
-        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 else:
     encoder_extended_attention_mask = None
```
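For readers following along: both helpers in this hunk produce *additive* masks that broadcast over the head dimension. A minimal self-contained sketch of what `get_extended_attention_mask` and `invert_attention_mask` boil down to for a 2D padding mask (the helper name below is ours, purely illustrative, not library API):

```python
import torch

def make_extended_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, seq] 0/1 padding mask -> [batch, 1, 1, seq] additive mask:
    # 0.0 where attention is allowed, dtype-min where it is masked out.
    extended = mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
extended = make_extended_attention_mask(mask, torch.float32)
print(extended.shape)  # torch.Size([1, 1, 1, 4])
```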
Can we align the mask creation with the approach in #38301?
It would make refactoring easier, and those mask creations are "more proven".
Yeah, I also thought of that, but since it supports non-causal masks as well, I think we need a cleaner approach for it in the future: some kind of small function that decides which mask to construct and returns it.
I don't mean the attention mask interface yet, as that definitely needs an update for non-causal variants :D I'm thinking of
transformers/src/transformers/models/bert/modeling_bert.py, lines 990 to 1034 at e0f1e83
(without the checks on the dims). It would make it easy to remove all those functions later on if we keep it consistent across models that have yet to get the mask interface. (And I have tested it quite thoroughly on all attention variations.)
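Not part of this PR, but to make the "small function that decides which mask to construct" idea concrete, here is a rough pure-torch sketch; the name and signature are hypothetical assumptions of ours, not the library's API:

```python
import torch

def build_additive_mask(padding_mask: torch.Tensor, is_causal: bool, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical dispatcher: one entry point returning either a causal or a
    # bidirectional additive mask, broadcastable to [batch, heads, q_len, kv_len].
    batch_size, seq_len = padding_mask.shape
    min_val = torch.finfo(dtype).min
    mask = (1.0 - padding_mask[:, None, None, :].to(dtype)) * min_val
    if is_causal:
        # Block attention to future positions on top of the padding mask.
        causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
        mask = (mask + causal).clamp(min=min_val)  # avoid -inf from double masking
    return mask

padding = torch.tensor([[1, 1, 1, 0]])
print(build_additive_mask(padding, is_causal=True, dtype=torch.float32).shape)  # [1, 1, 4, 4]
```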
run-slow: esm, evolla

This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']
vasqu left a comment
Oops, I commented on the modular-generated file, but this should be carried over to esm - my bad ^^'
vasqu left a comment
Just one small nit (the order of relative scaling), and could you change the mask creation per https://github.com/huggingface/transformers/pull/40370/files#r2297912836?
run-slow: esm, evolla

This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']

run-slow: esm, evolla
```python
self.scaling = 1.0  # For BC we apply scaling before RoPE
self.is_decoder = config.is_decoder
self.layer_idx = layer_idx
self.is_causal = self.is_decoder  # used only in FA2/FA3
```
SDPA uses this as well :D This is probably incorrect in the case of cross-attention; can you add an optional argument here instead?
Ah right, it also has cross-attention. I copied this from the current ESM attention, assuming it works on the main branch.
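For illustration, a minimal sketch of the optional argument being requested; the class name mirrors ESM's attention module, but this signature is our assumption, not the merged code:

```python
from typing import Optional
import torch.nn as nn

class EsmSelfAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None, is_causal: Optional[bool] = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_decoder = config.is_decoder
        # Cross-attention layers never attend causally, so callers can pass
        # is_causal=False explicitly instead of inheriting it from config.is_decoder.
        self.is_causal = config.is_decoder if is_causal is None else is_causal
```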
[For maintainers] Suggested jobs to run (before merge): run-slow: esm, evolla
Original PR #40370 by zucchini-nlp: huggingface/transformers#40370
Merged from original PR #40370.
What does this PR do?
Addresses #34954 and updates ESM to support the attention API and modeling outputs.
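For context on what "attention API" means here: instead of hard-coding one attention implementation, the module dispatches through a registry keyed by `config._attn_implementation` (eager, sdpa, flash_attention_2, ...). Below is a toy, runnable mirror of that dispatch pattern; the names and signatures are simplified and are not the library's exact API:

```python
import torch
import torch.nn.functional as F
from typing import Callable, Dict

def eager_attention(q, k, v, mask, scaling):
    # Plain softmax attention; keeps the weights materialized.
    weights = (q @ k.transpose(-2, -1)) * scaling
    if mask is not None:
        weights = weights + mask
    return weights.softmax(dim=-1) @ v

def sdpa_attention(q, k, v, mask, scaling):
    # Fused-kernel path via PyTorch's scaled_dot_product_attention.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling)

# Toy registry standing in for the library's implementation table.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}

q = k = v = torch.randn(1, 4, 8, 16)  # [batch, heads, seq, head_dim]
for impl in ("eager", "sdpa"):
    out = ATTENTION_FUNCTIONS[impl](q, k, v, None, scaling=16 ** -0.5)
    print(impl, out.shape)
```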