289 changes: 103 additions & 186 deletions src/transformers/models/esm/modeling_esm.py
@@ -16,14 +16,12 @@
"""PyTorch ESM model."""

import math
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
@@ -32,15 +30,16 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_esm import EsmConfig


if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)


@@ -274,7 +273,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None):
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.dropout_prob = config.attention_probs_dropout_prob
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
@@ -286,6 +285,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None):
self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

self.is_decoder = config.is_decoder
self.is_causal = False
self.layer_idx = layer_idx

def forward(
@@ -323,53 +323,85 @@ def forward(
if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in EsmModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
assert head_mask is None, "head_mask is only supported for eager attention"
assert "relative" not in self.position_embedding_type, (
"relative position embeddings are only supported for eager attention"
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=self.dropout_prob,
head_mask=head_mask,
hidden_states=hidden_states,
)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
attn_output = attn_output.view(new_context_layer_shape)

outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
if self.is_decoder:
outputs = outputs + (None,)
return outputs
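The inlined attention math moves into eager_attention_forward (defined just below), and any other configured backend is fetched from the shared ALL_ATTENTION_FUNCTIONS registry under the same calling convention. A minimal sketch of that dispatch pattern; sdpa_attention, ATTENTION_REGISTRY, and pick_attention are illustrative names, and the scale argument of F.scaled_dot_product_attention assumes PyTorch 2.1+:

from typing import Callable

import torch.nn.functional as F

# Illustrative stand-in for the ALL_ATTENTION_FUNCTIONS registry: every
# backend is a callable sharing eager_attention_forward's convention of
# (module, query, key, value, attention_mask, **kwargs) -> (context, weights).
def sdpa_attention(module, query, key, value, attention_mask=None, dropout=0.0, **kwargs):
    # Fused softmax(QK^T + mask)V kernel; attention weights are never
    # materialized, hence the None. scale=1.0 because ESM pre-scales the query.
    context = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=1.0
    )
    return context.transpose(1, 2), None  # -> (batch, seq_len, heads, head_dim)

ATTENTION_REGISTRY: dict[str, Callable] = {"sdpa": sdpa_attention}

def pick_attention(attn_implementation: str) -> Callable:
    # Mirrors the dispatch in EsmSelfAttention.forward: eager is the default,
    # anything else is looked up by implementation name. Assumes
    # eager_attention_forward from this module is in scope.
    if attn_implementation == "eager":
        return eager_attention_forward
    return ATTENTION_REGISTRY[attn_implementation]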


def eager_attention_forward(
module: nn.Module,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
scaling: float = 1.0,  # accepted for interface parity; unused here because ESM pre-scales the query
dropout: float = 0.0,
head_mask: Optional[torch.FloatTensor] = None,
hidden_states: Optional[torch.FloatTensor] = None,
**kwargs,
):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in EsmModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
return context_layer, attention_probs
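A quick shape check for the extracted function; the stub below is hypothetical and provides only the attributes the eager path reads on this input (position_embedding_type and training), since with "absolute" embeddings the relative-embedding branch is never touched:

import torch
from torch import nn

class AttnStub(nn.Module):
    # Hypothetical stand-in: distance_embedding / max_position_embeddings
    # are only needed for the "relative_key*" embedding types.
    def __init__(self):
        super().__init__()
        self.position_embedding_type = "absolute"

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

context, probs = eager_attention_forward(AttnStub(), q, k, v, attention_mask=None)
print(context.shape)  # torch.Size([2, 8, 4, 16]): (batch, seq_len, heads, head_dim)
print(probs.shape)    # torch.Size([2, 4, 8, 8]):  (batch, heads, seq_len, seq_len)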


class EsmSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
@@ -382,129 +414,10 @@ def forward(self, hidden_states, input_tensor):
hidden_states = hidden_states + input_tensor
return hidden_states


class EsmFlashAttention2(EsmSelfAttention):
"""
ESM flash attention module. This module inherits from `EsmSelfAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, which needs to call the public flash-attention API
correctly and handle padding tokens in case the input contains any.
"""

def __init__(self, config, position_embedding_type=None, layer_idx=None):
super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
self.dropout_prob = config.attention_probs_dropout_prob

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
# Flash attention doesn't support output_attentions or cross attention
if output_attentions or head_mask is not None or encoder_hidden_states is not None:
logger.warning_once(
"EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
"Falling back to the manual attention implementation. This warning can be removed using "
'the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)

bsz, q_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states get silently cast to float32 as well. We therefore
# cast them back to the correct dtype to make sure everything works as expected.
# This can slow down training and inference, so it is recommended not to cast the
# LayerNorms to fp32.
input_dtype = query_layer.dtype
device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = (
torch.get_autocast_dtype(device_type)
if hasattr(torch, "get_autocast_dtype")
else torch.get_autocast_gpu_dtype()
)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.query.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_layer = query_layer.to(target_dtype)
key_layer = key_layer.to(target_dtype)
value_layer = value_layer.to(target_dtype)

# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
# but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
# ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5

if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")

# It would likely be faster to change self.transpose_for_scores to output the correct
# dimensions for flash_attention_2, but that would also mean changing the rotary embedding
# functions. Here we just permute the dimensions to match the expected input.
attn_output = _flash_attention_forward(
query_layer.permute(0, 2, 1, 3),
key_layer.permute(0, 2, 1, 3),
value_layer.permute(0, 2, 1, 3),
attention_mask,
query_length=q_len,
is_causal=self.is_decoder,
softmax_scale=1.0,
dropout=self.dropout_prob if self.training else 0.0,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)

attn_output = attn_output.reshape(bsz, q_len, -1)

outputs = (attn_output, None)
if self.is_decoder:
outputs = outputs + (None,)

return outputs


ESM_ATTENTION_CLASSES = {
"eager": EsmSelfAttention,
"flash_attention_2": EsmFlashAttention2,
}


class EsmAttention(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.self = EsmSelfAttention(config, layer_idx=layer_idx)
self.output = EsmSelfOutput(config)
self.pruned_heads = set()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -733,6 +646,8 @@ class EsmPreTrainedModel(PreTrainedModel):
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
def _init_weights(self, module):
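With `_supports_sdpa` and `_supports_flex_attn` advertised alongside flash attention, a backend can be requested at load time. A hedged usage sketch; the checkpoint name is just a small public ESM-2 model, and flex attention additionally requires a recent PyTorch:

import torch
from transformers import AutoModel

# Any of "eager", "sdpa", or "flash_attention_2" should now be accepted;
# omitting attn_implementation lets the library pick one automatically.
model = AutoModel.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)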
@@ -823,22 +738,22 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.

[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.

[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -862,8 +777,10 @@ def forward(
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length), device=device)

if self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask
if self.config._attn_implementation == "sdpa":
extended_attention_mask = attention_mask[:, None, None, :].to(bool)
elif self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask.to(bool)

else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
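The rest of the eager branch is truncated above; for the two new paths, a small illustration of the expected mask layouts with a right-padded batch (the eager branch instead builds an additive float mask with large negative values at padded positions):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 1 = real token, 0 = padding

# SDPA path: broadcastable 4D boolean mask, True = "may attend".
sdpa_mask = attention_mask[:, None, None, :].to(torch.bool)
print(sdpa_mask.shape)   # torch.Size([2, 1, 1, 4])

# Flash-attention path: the kernel unpads internally, so the plain 2D
# boolean mask is forwarded unchanged.
flash_mask = attention_mask.to(torch.bool)
print(flash_mask.shape)  # torch.Size([2, 4])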
6 changes: 4 additions & 2 deletions tests/models/esm/test_modeling_esm.py
@@ -327,11 +327,13 @@ def test_flash_attn_2_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)

dummy_input = inputs_dict[model_class.main_input_name]
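The remainder of the test is truncated above; for orientation, a minimal sketch of the kind of comparison it performs, with illustrative tolerances (bfloat16 trades mantissa precision for float16's wider exponent range, which is why the test switches dtypes and why looser bounds than a float32 test are expected):

import torch

with torch.no_grad():
    out_eager = model(dummy_input).last_hidden_state
    out_fa = model_fa(dummy_input).last_hidden_state

# bfloat16 carries roughly 3 significant decimal digits, hence the loose bounds.
torch.testing.assert_close(out_eager.float(), out_fa.float(), atol=4e-2, rtol=4e-2)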