From 7b343084276d7a707618cfcf819371a5fa55cb97 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 15 Aug 2025 07:33:59 -0700 Subject: [PATCH 1/3] initial attempts Signed-off-by: Peter St. John --- src/transformers/models/esm/modeling_esm.py | 121 ++++++++++++-------- tests/models/esm/test_modeling_esm.py | 4 +- 2 files changed, 78 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 3bf625e92477..fd9c1008f2b9 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -16,7 +16,7 @@ """PyTorch ESM model.""" import math -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -32,7 +32,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS,PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, can_return_tuple, logging from .configuration_esm import EsmConfig @@ -274,7 +274,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.dropout_prob = config.attention_probs_dropout_prob self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) @@ -323,53 +323,84 @@ def forward( if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. 
- attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + assert head_mask is None, "head_mask is only supported for eager attention" + assert "relative" not in self.position_embedding_type, "relative position embeddings are only supported for eager attention" + + attn_output, attn_weights = attention_interface( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + dropout=self.dropout_prob, + head_mask=head_mask, + hidden_states=hidden_states, + ) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) if self.is_decoder: outputs = outputs + (None,) return outputs +def eager_attention_forward( + module: nn.Module, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + scaling: int = 1.0, + dropout: float = 0.0, + head_mask: Optional[torch.FloatTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if module.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif module.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (module.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + class EsmSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -828,7 +859,7 @@ def forward( Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - +f [What are input IDs?](../glossary#input-ids) position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 79dd701efdbb..e0191454ffca 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -327,11 +327,11 @@ def test_flash_attn_2_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager") model.to(torch_device) dummy_input = inputs_dict[model_class.main_input_name] From 178b3219c76805ac274c4f0ec8d70c78f8fd917d Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 15 Aug 2025 08:19:05 -0700 Subject: [PATCH 2/3] running format Signed-off-by: Peter St. 
John --- src/transformers/models/esm/modeling_esm.py | 43 ++++++++++++--------- tests/models/esm/test_modeling_esm.py | 4 +- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index fd9c1008f2b9..f5efe462fa13 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -32,7 +32,12 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS,PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, + PreTrainedModel, + find_pruneable_heads_and_indices, + prune_linear_layer, +) from ...utils import auto_docstring, can_return_tuple, logging from .configuration_esm import EsmConfig @@ -327,7 +332,9 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] assert head_mask is None, "head_mask is only supported for eager attention" - assert "relative" not in self.position_embedding_type, "relative position embeddings are only supported for eager attention" + assert "relative" not in self.position_embedding_type, ( + "relative position embeddings are only supported for eager attention" + ) attn_output, attn_weights = attention_interface( self, @@ -854,22 +861,22 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" - input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. -f - [What are input IDs?](../glossary#input-ids) - position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length), hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. + input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + f + [What are input IDs?](../glossary#input-ids) + position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length), hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index e0191454ffca..4d7802508b3c 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -331,7 +331,9 @@ def test_flash_attn_2_equivalence(self): ) model_fa.to(torch_device) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager") + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager" + ) model.to(torch_device) dummy_input = inputs_dict[model_class.main_input_name] From fa9a4fd1850bba8dcbda2b188b463073ee8272ee Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 15 Aug 2025 14:54:54 -0700 Subject: [PATCH 3/3] model cleanup Signed-off-by: Peter St. John --- src/transformers/models/esm/modeling_esm.py | 143 ++------------------ 1 file changed, 11 insertions(+), 132 deletions(-) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index f5efe462fa13..1baeb37490e7 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -19,11 +19,9 @@ from typing import Callable, Optional, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, @@ -42,10 +40,6 @@ from .configuration_esm import EsmConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) @@ -291,6 +285,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) self.is_decoder = config.is_decoder + self.is_causal = False self.layer_idx = layer_idx def forward( @@ -347,6 +342,9 @@ def forward( hidden_states=hidden_states, ) + new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,) + attn_output = attn_output.view(new_context_layer_shape) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) if self.is_decoder: outputs = outputs + (None,) @@ -400,11 +398,7 @@ def eager_attention_forward( attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (module.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - return context_layer, attention_probs @@ -420,129 +414,10 @@ def forward(self, hidden_states, input_tensor): hidden_states = hidden_states + input_tensor return hidden_states - -class EsmFlashAttention2(EsmSelfAttention): - """ - ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ - - def __init__(self, config, position_embedding_type=None, layer_idx=None): - super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - self.dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.Tensor]: - # Flash attention doesn't support output_attentions or cross attention - if output_attentions or head_mask is not None or encoder_hidden_states is not None: - logger.warning_once( - "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. " - "Falling back to the manual attention implementation. This warning can be removed using " - 'the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - - bsz, q_len, _ = hidden_states.size() - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - input_dtype = query_layer.dtype - device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu" - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = ( - torch.get_autocast_dtype(device_type) - if hasattr(torch, "get_autocast_dtype") - else torch.get_autocast_gpu_dtype() - ) - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.query.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_layer = query_layer.to(target_dtype) - key_layer = key_layer.to(target_dtype) - value_layer = value_layer.to(target_dtype) - - # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). - # ESM scales the query down by the same factor instead. 
Modulo numerical stability these are equivalent, - # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original - # ESM code and fix rotary embeddings. - query_layer = query_layer * self.attention_head_size**-0.5 - - if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings") - - # It would likely be faster to change self.transpose_for_scores to output the correct - # dimensions for flash_attention_2, but that would also mean changing the rotary embedding - # functions. Here we just permute the dimensions to match the expected input. - attn_output = _flash_attention_forward( - query_layer.permute(0, 2, 1, 3), - key_layer.permute(0, 2, 1, 3), - value_layer.permute(0, 2, 1, 3), - attention_mask, - query_length=q_len, - is_causal=self.is_decoder, - softmax_scale=1.0, - dropout=self.dropout_prob if self.training else 0.0, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1) - - outputs = (attn_output, None) - if self.is_decoder: - outputs = outputs + (None,) - - return outputs - - -ESM_ATTENTION_CLASSES = { - "eager": EsmSelfAttention, - "flash_attention_2": EsmFlashAttention2, -} - - class EsmAttention(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() - self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.self = EsmSelfAttention(config, layer_idx=layer_idx) self.output = EsmSelfOutput(config) self.pruned_heads = set() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -771,6 +646,8 @@ class EsmPreTrainedModel(PreTrainedModel): _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"] _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead def _init_weights(self, module): @@ -900,8 +777,10 @@ def forward( if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length)), device=device) - if self.config._attn_implementation == "flash_attention_2": - extended_attention_mask = attention_mask + if self.config._attn_implementation == "sdpa": + extended_attention_mask = attention_mask[:, None, None, :].to(bool) + elif self.config._attn_implementation == "flash_attention_2": + extended_attention_mask = attention_mask.to(bool) else: # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
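
Taken together, the three patches replace ESM's bespoke EsmFlashAttention2 class with the shared attention-interface dispatch: EsmSelfAttention.forward now looks up the configured backend in ALL_ATTENTION_FUNCTIONS, keeps head_mask and the relative_key* position embeddings on the eager path only, and the model advertises _supports_flash_attn, _supports_sdpa, and _supports_flex_attn. The sketch below is one way to exercise the new dispatch; it is not part of the series itself. The checkpoint name facebook/esm2_t6_8M_UR50D and the comparison tolerance are illustrative assumptions, and it presumes a transformers build that contains these patches.

import torch
from transformers import AutoTokenizer, EsmModel

# Assumed public ESM-2 checkpoint, chosen only for illustration.
checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

# Eager attention keeps support for head_mask and the relative_key* position embeddings.
model_eager = EsmModel.from_pretrained(checkpoint, attn_implementation="eager").eval()

# SDPA is now routed through ALL_ATTENTION_FUNCTIONS rather than a custom attention class.
model_sdpa = EsmModel.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = model_eager(**inputs).last_hidden_state
    out_sdpa = model_sdpa(**inputs).last_hidden_state

# The two backends should agree to within numerical tolerance.
print(torch.allclose(out_eager, out_sdpa, atol=1e-4))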