From 5d14a87727e8e63cf108be9522d5fc248ac799b8 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 4 Jun 2025 17:30:49 +0200 Subject: [PATCH 01/58] dump --- .../models/align/modeling_align.py | 133 ++++++------ .../models/altclip/modeling_altclip.py | 133 ++++++------ src/transformers/models/bert/modeling_bert.py | 193 ++++++++++-------- .../modeling_bert_generation.py | 135 ++++++------ .../bridgetower/modeling_bridgetower.py | 108 ++++++---- .../models/camembert/modeling_camembert.py | 191 +++++++++-------- .../chinese_clip/modeling_chinese_clip.py | 137 +++++++------ src/transformers/models/clap/modeling_clap.py | 133 ++++++------ src/transformers/models/clvp/modeling_clvp.py | 92 +++++---- .../models/cpmant/modeling_cpmant.py | 102 +++++---- src/transformers/models/ctrl/modeling_ctrl.py | 75 ++++--- .../models/data2vec/modeling_data2vec_text.py | 135 ++++++------ .../models/electra/modeling_electra.py | 133 ++++++------ .../models/ernie/modeling_ernie.py | 133 ++++++------ .../models/layoutlm/modeling_layoutlm.py | 133 ++++++------ src/transformers/models/led/modeling_led.py | 149 ++++++++------ src/transformers/models/lilt/modeling_lilt.py | 4 +- .../models/markuplm/modeling_markuplm.py | 133 ++++++------ .../megatron_bert/modeling_megatron_bert.py | 62 +++--- src/transformers/models/mpt/modeling_mpt.py | 77 ++++--- src/transformers/models/mvp/modeling_mvp.py | 154 +++++++------- .../models/nllb_moe/modeling_nllb_moe.py | 129 +++++++----- .../models/prophetnet/modeling_prophetnet.py | 175 +++++++++------- src/transformers/models/rag/modeling_rag.py | 15 +- .../models/rembert/modeling_rembert.py | 127 +++++++----- .../models/roberta/modeling_roberta.py | 191 +++++++++-------- .../modeling_roberta_prelayernorm.py | 127 +++++++----- .../models/roc_bert/modeling_roc_bert.py | 133 ++++++------ .../models/roformer/modeling_roformer.py | 126 +++++++----- .../seamless_m4t/modeling_seamless_m4t.py | 158 +++++++------- .../modeling_seamless_m4t_v2.py | 142 ++++++++----- .../speech_to_text/modeling_speech_to_text.py | 54 +++-- .../models/splinter/modeling_splinter.py | 133 ++++++------ .../models/superglue/modeling_superglue.py | 72 ++++--- .../models/tapas/modeling_tapas.py | 26 +-- src/transformers/models/xglm/modeling_xglm.py | 138 +++++++------ .../xlm_roberta/modeling_xlm_roberta.py | 191 +++++++++-------- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 120 ++++++----- src/transformers/models/xmod/modeling_xmod.py | 62 +++--- 39 files changed, 2681 insertions(+), 1983 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index bdebd31a266c..813106d14aec 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPastAndCrossAttentions, @@ -590,7 +591,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText class AlignTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -615,6 +616,7 @@ def __init__(self, config, position_embedding_type=None): 
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -628,8 +630,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -638,43 +641,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -745,10 +749,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT class AlignTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = AlignTextSelfOutput(config) self.pruned_heads = set() @@ -778,8 +784,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -789,6 +796,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -828,17 +836,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText class AlignTextLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = AlignTextAttention(config) + self.attention = AlignTextAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = AlignTextAttention(config, position_embedding_type="absolute") + self.crossattention = AlignTextAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = AlignTextIntermediate(config) self.output = AlignTextOutput(config) @@ -849,28 +857,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, 
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -878,24 +884,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -903,7 +904,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -915,10 +916,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText class AlignTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([AlignTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -933,6 +934,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -945,13 +947,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, 
tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -961,8 +972,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -971,13 +983,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -986,12 +999,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -1000,7 +1017,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 5637bc6deec3..023fc68ca10a 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -181,7 +182,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta class AltRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -206,6 +207,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -219,8 
+221,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -229,43 +232,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -336,10 +340,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA class AltRobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = AltRobertaSelfOutput(config) self.pruned_heads = set() @@ -369,8 +375,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -380,6 +387,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -419,17 +427,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta class AltRobertaLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = AltRobertaAttention(config) + self.attention = AltRobertaAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute") + self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = AltRobertaIntermediate(config) self.output = AltRobertaOutput(config) @@ -440,28 +448,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -469,24 +475,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -494,7 +495,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -506,10 +507,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta class AltRobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([AltRobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -524,6 +525,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -536,13 +538,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. 
" + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -552,8 +563,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -562,13 +574,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -577,12 +590,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -591,7 +608,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index dd738ea96ebd..4b8a0d9c0a17 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -188,7 +189,7 @@ def forward( class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -213,6 +214,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -226,8 +228,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = 
None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -236,43 +239,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -322,8 +326,8 @@ def forward( class BertSdpaSelfAttention(BertSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") @@ -335,13 +339,14 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards.
This warning can be removed using the argument " + `attn_implementation="eager"` when loading the model." ) return super().forward( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() @@ -368,25 +374,35 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] + else: + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
@@ -443,10 +459,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BertSelfOutput(config) self.pruned_heads = set() @@ -476,8 +494,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -487,6 +506,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -523,17 +543,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertAttention(config) + self.attention = BertAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) @@ -544,28 +564,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -573,24 +591,19 @@ def forward( " by setting 
`config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -598,7 +611,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -609,10 +622,10 @@ def feed_forward_chunk(self, attention_output): class BertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -627,6 +640,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -639,13 +653,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -655,8 +678,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -665,13 +689,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -680,12 +705,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -694,7 +723,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -913,6 +942,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1024,6 +1054,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1190,6 +1221,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **loss_kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -1216,6 +1248,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 6dee1db6fbdc..d6cce5b2520f 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import 
GenerationMixin from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -53,7 +54,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->BertGeneration class BertGenerationSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -78,6 +79,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -91,8 +93,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -101,43 +104,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position 
= cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -193,10 +197,12 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION class BertGenerationAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BertGenerationSelfOutput(config) self.pruned_heads = set() @@ -226,8 +232,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -237,6 +244,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -276,17 +284,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration class BertGenerationLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertGenerationAttention(config) + self.attention = BertGenerationAttention(config, layer_idx) self.is_decoder = config.is_decoder 
self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BertGenerationAttention(config, position_embedding_type="absolute") + self.crossattention = BertGenerationAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = BertGenerationIntermediate(config) self.output = BertGenerationOutput(config) @@ -297,28 +307,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -326,24 +334,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -351,7 +354,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -363,10 +366,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->BertGeneration class BertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BertGenerationLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertGenerationLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = 
False def forward( @@ -381,6 +384,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -393,13 +397,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -409,8 +422,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -419,13 +433,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -434,12 +449,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -448,7 +467,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index e9ba3f272ce3..60cec997e500 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -410,7 +411,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower class 
BridgeTowerSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -435,6 +436,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -448,8 +450,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -458,43 +461,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True 
query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -550,10 +554,12 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER class BridgeTowerAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BridgeTowerSelfOutput(config) self.pruned_heads = set() @@ -583,8 +589,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -594,6 +601,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -750,10 +758,12 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText class BridgeTowerTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BridgeTowerTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [BridgeTowerTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False def forward( @@ -768,6 +778,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else 
None all_self_attentions = () if output_attentions else None @@ -780,13 +791,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -796,8 +816,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -806,13 +827,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -821,12 +843,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -835,7 +861,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 733f31307892..bc241dce4438 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -138,7 +139,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Camembert class CamembertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ 
-163,6 +164,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -176,8 +178,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -186,43 +189,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -273,8 +277,8 @@ def forward( # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert class CamembertSdpaSelfAttention(CamembertSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") @@ -286,13 +290,14 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMCamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. 
This warning can be removed using the argument "
@@ -306,6 +311,7 @@ def forward(
            encoder_attention_mask,
            past_key_value,
            output_attentions,
+            cache_position,
        )
        bsz, tgt_len, _ = hidden_states.size()
@@ -319,25 +325,35 @@ def forward(
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
-        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
-        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
-            key_layer, value_layer = past_key_value
-        else:
-            key_layer = self.transpose_for_scores(self.key(current_states))
-            value_layer = self.transpose_for_scores(self.value(current_states))
-            if past_key_value is not None and not is_cross_attention:
-                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_layer = curr_past_key_value.key_cache[self.layer_idx]
+            value_layer = curr_past_key_value.value_cache[self.layer_idx]
+        else:
+            key_layer = self.transpose_for_scores(self.k_proj(current_states))
+            value_layer = self.transpose_for_scores(self.v_proj(current_states))
+
+            if past_key_value is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
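The hunks above all apply the same caching pattern: select the self- or cross-attention sub-cache by `layer_idx`, reuse cross-attention key/values once `is_updated` is set, and otherwise write the freshly projected states back through `Cache.update`. The sketch below mirrors that control flow outside of any model, assuming a transformers version that already ships `DynamicCache`/`EncoderDecoderCache` as this patch does; the `toy_cached_kv` helper and the tensor shapes are illustrative stand-ins, not code from this patch.

# Illustrative sketch only: mimics the caching control flow applied in the hunks above.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache


def toy_cached_kv(hidden_states, encoder_hidden_states, past_key_value, layer_idx, cache_position):
    """Return (key, value) for one attention layer, reading/writing the cache like the patched models."""
    is_cross_attention = encoder_hidden_states is not None

    # Pick the sub-cache that belongs to this attention type.
    if isinstance(past_key_value, EncoderDecoderCache):
        is_updated = past_key_value.is_updated.get(layer_idx)
        curr_cache = (
            past_key_value.cross_attention_cache if is_cross_attention else past_key_value.self_attention_cache
        )
    else:
        is_updated, curr_cache = False, past_key_value

    current_states = encoder_hidden_states if is_cross_attention else hidden_states
    if is_cross_attention and is_updated:
        # Cross-attention K/V depend only on the encoder output, so they are reused after the first step.
        key, value = curr_cache.key_cache[layer_idx], curr_cache.value_cache[layer_idx]
    else:
        # Stand-in for k_proj/v_proj + transpose_for_scores: tensors shaped (batch, heads, seq, head_dim).
        key, value = current_states, current_states
        # Self-attention appends along the sequence axis; cross-attention is stored once (no cache_position).
        key, value = curr_cache.update(
            key, value, layer_idx, {"cache_position": None if is_cross_attention else cache_position}
        )
        if is_cross_attention and isinstance(past_key_value, EncoderDecoderCache):
            past_key_value.is_updated[layer_idx] = True
    return key, value


cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
decoder_states = torch.randn(1, 2, 4, 8)   # already projected, (batch, heads, seq, head_dim)
encoder_states = torch.randn(1, 2, 6, 8)
toy_cached_kv(decoder_states, None, cache, layer_idx=0, cache_position=torch.arange(4))   # self-attn: cache grows
toy_cached_kv(decoder_states, encoder_states, cache, layer_idx=0, cache_position=None)    # cross-attn: stored once, reused afterwards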
@@ -396,10 +412,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT class CamembertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = CamembertSelfOutput(config) self.pruned_heads = set() @@ -429,8 +447,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -440,6 +459,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -479,17 +499,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert class CamembertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = CamembertAttention(config) + self.attention = CamembertAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = CamembertAttention(config, position_embedding_type="absolute") + self.crossattention = CamembertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = CamembertIntermediate(config) self.output = CamembertOutput(config) @@ -500,28 +520,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = 
self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -529,24 +547,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -554,7 +567,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -566,10 +579,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->Camembert class CamembertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([CamembertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([CamembertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -584,6 +597,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -596,13 +610,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -612,8 +635,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -622,13 +646,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -637,12 +662,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -651,7 +680,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -826,6 +855,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -937,6 +967,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index bc1421e7157a..747f3ba7f114 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -240,7 +241,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=Fals # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText class ChineseCLIPTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 
and not hasattr(config, "embedding_size"): raise ValueError( @@ -265,6 +266,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -278,8 +280,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -288,43 +291,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -395,10 +399,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT class ChineseCLIPTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ChineseCLIPTextSelfOutput(config) self.pruned_heads = set() @@ -428,8 +434,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -439,6 +446,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -578,17 +586,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText class ChineseCLIPTextLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ChineseCLIPTextAttention(config) + self.attention = ChineseCLIPTextAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute") + self.crossattention = ChineseCLIPTextAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = ChineseCLIPTextIntermediate(config) self.output = ChineseCLIPTextOutput(config) @@ -599,28 +609,26 @@ 
def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -628,24 +636,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -653,7 +656,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -778,10 +781,12 @@ def _init_weights(self, module): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText class ChineseCLIPTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [ChineseCLIPTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False def forward( @@ -796,6 +801,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -808,13 +814,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + 
return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -824,8 +839,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -834,13 +850,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -849,12 +866,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -863,7 +884,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index a08849071701..63736ba2cfad 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -24,6 +24,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, @@ -1118,7 +1119,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText class ClapTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -1143,6 +1144,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + 
(self.num_attention_heads, self.attention_head_size) @@ -1156,8 +1158,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -1166,43 +1169,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -1273,10 +1277,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT class ClapTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ClapTextSelfOutput(config) self.pruned_heads = set() @@ -1306,8 +1312,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -1317,6 +1324,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -1356,17 +1364,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText class ClapTextLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ClapTextAttention(config) + self.attention = ClapTextAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ClapTextAttention(config, position_embedding_type="absolute") + self.crossattention = ClapTextAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = ClapTextIntermediate(config) self.output = ClapTextOutput(config) @@ -1377,28 +1385,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: 
Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -1406,24 +1412,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -1431,7 +1432,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -1443,10 +1444,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText class ClapTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ClapTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -1461,6 +1462,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1473,13 +1475,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1489,8 +1500,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -1499,13 +1511,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -1514,12 +1527,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -1528,7 +1545,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 677858fe804e..e5d4b5de9b01 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -274,7 +275,7 @@ class ClvpSelfAttention(nn.Module): Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module. 
""" - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -287,6 +288,7 @@ def __init__(self, config): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.layer_idx = layer_idx if hasattr(config, "max_position_embeddings"): max_positions = config.max_position_embeddings @@ -308,10 +310,11 @@ def forward( rotary_pos_emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying # rotary_pos_emb to query and key states. @@ -326,14 +329,9 @@ def forward( value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if past_key_value is not None: - past_key, past_value = past_key_value - key_states = torch.cat((past_key, key_states), dim=-2) - value_states = torch.cat((past_value, value_states), dim=-2) - - if use_cache is True: - present = (key_states, value_states) - else: - present = None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) if rotary_pos_emb is not None: rotary_emb_dim = rotary_pos_emb.shape[-1] @@ -394,7 +392,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, present, attn_weights + return attn_output, past_key_value, attn_weights class ClvpGatedLinearUnit(nn.Module): @@ -614,13 +612,13 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl class ClvpDecoderLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = ClvpSelfAttention(config) + self.attn = ClvpSelfAttention(config, layer_idx=layer_idx) self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = ClvpDecoderMLP(inner_dim, config) @@ -628,12 +626,13 @@ def __init__(self, config): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -645,6 +644,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] outputs = attn_outputs[1:] @@ -1013,7 +1013,9 @@ def __init__(self, config): self.position_embeds_layer = 
nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size) self.drop = nn.Dropout(self.config.embd_pdrop) - self.layers = nn.ModuleList([ClvpDecoderLayer(self.config) for _ in range(self.config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ClvpDecoderLayer(self.config, layer_idx=i) for i in range(self.config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon) self.gradient_checkpointing = False @@ -1048,6 +1050,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1074,11 +1077,24 @@ def forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_key_values_length = 0 - past_key_values = tuple([None] * len(self.layers)) - else: - past_key_values_length = past_key_values[0][0].size(-2) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange( past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device @@ -1110,18 +1126,11 @@ def forward( output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, past_key_value) in enumerate(zip(self.layers, past_key_values)): + for i, block in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1133,21 +1142,23 @@ def forward( attention_mask, position_ids, head_mask[i], + cache_position, ) else: outputs = block( hidden_states, - past_key_value=past_key_value, + past_key_value=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] if use_cache is True: - presents = presents + (outputs[1],) + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -1162,16 +1173,20 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1211,6 +1226,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1232,6 +1248,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1346,7 +1363,7 @@ def prepare_inputs_for_generation( token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1366,8 +1383,7 @@ def prepare_inputs_for_generation( # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :].unsqueeze(-1) else: position_ids = None @@ -1405,6 +1421,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, 
sequence_length)`, *optional*): @@ -1432,6 +1449,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1716,6 +1734,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = False, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, ClvpOutput]: r""" input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, time_dim)`): @@ -1771,6 +1790,7 @@ def forward( inputs_embeds=conditioning_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) speech_ids = decoder_outputs[0] diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index e437672aa58e..a684cd450aee 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -60,11 +61,12 @@ def forward(self, hidden_states: torch.Tensor): class CpmAntAttention(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() self.dim_model = config.hidden_size self.num_heads = config.num_attention_heads self.dim_head = config.dim_head + self.layer_idx = layer_idx self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) @@ -86,8 +88,9 @@ def forward( attention_mask: torch.BoolTensor, position_bias: torch.Tensor, output_attentions: Optional[bool] = False, - past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -120,8 +123,7 @@ def forward( value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3) if past_key_values is not None: - key = torch.cat([past_key_values[0], key], dim=-2) - value = torch.cat([past_key_values[1], value], dim=-2) + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) len_k = key.size(-2) # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k) @@ -156,18 +158,14 @@ def forward( score = self.attention_out(score) - past_key_values = None - if use_cache: - past_key_values = (key, value) - return score, attn_weights, past_key_values class CpmAntSelfAttentionBlock(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() self.layernorm_before_attention = CpmAntLayerNorm(config) - self.self_attention = CpmAntAttention(config) + self.self_attention = CpmAntAttention(config, layer_idx=layer_idx) if config.dropout_p: self.dropout = torch.nn.Dropout(config.dropout_p) else: @@ -179,8 +177,9 @@ def forward( attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - past_key_values: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -200,7 +199,14 @@ def forward( """ outputs = self.layernorm_before_attention(hidden_states) outputs = self.self_attention( - outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache + outputs, + outputs, + attention_mask, + position_bias, + output_attentions, + past_key_values, + use_cache, + cache_position, ) outputs, attn_weights, current_key_value = outputs @@ -286,9 +292,9 @@ def forward( class CpmAntTransformerBlock(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() - self.self_att = CpmAntSelfAttentionBlock(config) + self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx) self.ffn = CpmAntFFNBlock(config) def forward( @@ -297,8 +303,9 @@ def forward( attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -323,6 +330,7 @@ def forward( output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, + cache_position=cache_position, ) hidden_states, attn_weights, current_key_value = hidden_states @@ -336,7 +344,7 @@ class CpmAntEncoder(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.num_layers = config.num_hidden_layers - self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for ith in range(self.num_layers)]) + self.layers = nn.ModuleList([CpmAntTransformerBlock(config, layer_idx=i) for i in range(self.num_layers)]) self.output_layernorm = CpmAntLayerNorm(config) @@ -347,8 +355,9 @@ def forward( position_bias: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_postion: Optional[torch.Tensor] = None, ): """ Args: @@ -370,7 +379,6 @@ def forward( """ all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - current_key_values = () if use_cache else None for i, layer in enumerate(self.layers): if output_hidden_states: @@ -383,18 +391,16 @@ def forward( past_key_values=past_key_values[i] if past_key_values else None, use_cache=use_cache, ) - hidden_states, attn_weights, current_key_value = layer_outputs + hidden_states, attn_weights, past_key_values = layer_outputs if output_attentions: all_self_attns += (attn_weights,) - if current_key_value is not None: - current_key_values = current_key_values + (current_key_value,) hidden_states = self.output_layernorm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) - return hidden_states, current_key_values, all_hidden_states, all_self_attns + return hidden_states, past_key_values, all_hidden_states, all_self_attns # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt @@ -592,6 +598,7 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> 
Union[Tuple[torch.Tensor], BaseModelOutputWithPast]: r""" @@ -634,17 +641,24 @@ def forward( position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1) span = torch.full((batch, seq_length), 0, dtype=dtype, device=device) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.encoder.num_layers) - input_ids = input_ids.contiguous() - hidden_states = self.input_embedding(input_ids) - segment_states = self.segment_embedding(segment) - hidden_states = hidden_states + segment_states - else: - past_length = past_key_values[0][0].size(-2) - segment_states = self.segment_embedding(segment) - hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :] + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + input_ids = input_ids.contiguous() + hidden_states = self.input_embedding(input_ids) + segment_states = self.segment_embedding(segment) + if past_length != 0: + segment_states = segment_states[:, -1:, :] + + hidden_states = hidden_states + segment_states attention_mask = self._prepare_attention_mask(input_ids, span, context, length) position_bias = self.position_bias(position, position, segment, segment) @@ -653,7 +667,7 @@ def forward( position_bias = position_bias[:, :, past_length:, :] hidden_states = hidden_states[:, past_length:, :] - hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder( + hidden_states, next_decoder_cache, all_hidden_states, all_attentions = self.encoder( hidden_states, attention_mask, position_bias, @@ -661,6 +675,7 @@ def forward( output_hidden_states, past_key_values, use_cache, + cache_position, ) if past_length == 0: @@ -677,14 +692,16 @@ def forward( new_hidden_states += (hidden_state[:, self.prompt_length :, :],) all_hidden_states = new_hidden_states + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: - return tuple( - v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None - ) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=present_key_values, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -719,6 +736,7 @@ def forward( labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -751,7 +769,13 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict model_output = self.cpmant( - input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict + input_ids, + output_attentions, + output_hidden_states, + past_key_values, + use_cache, + return_dict, + cache_position, ) 
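# --- Illustrative sketch, not part of this patch ---
# The CpmAnt hunks above replace the tuple-of-tuples cache with a `DynamicCache` and only convert
# back to the legacy format when the caller passed one. A minimal, self-contained round trip with
# made-up shapes (batch=1, heads=4, seq=3, head_dim=8, two layers):
import torch
from transformers.cache_utils import DynamicCache

legacy = tuple(
    (torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8))  # (key, value) per layer
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 3

# Cache.update() appends the new key/value states for a layer and returns the full tensors,
# which is what the attention modules touched by this patch consume.
new_k, new_v = torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8)
k, v = cache.update(new_k, new_v, layer_idx=0)
assert k.shape[-2] == 4

legacy_again = cache.to_legacy_cache()  # deprecated format handed back when `return_legacy_cache` is set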
hidden_states = model_output.last_hidden_state if return_dict else model_output[0] diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 1896d6ea4130..5375e4c2d49c 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -83,10 +84,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N class MultiHeadAttention(nn.Module): - def __init__(self, d_model_size, num_heads): + def __init__(self, d_model_size, num_heads, layer_idx=None): super().__init__() self.num_heads = num_heads self.d_model_size = d_model_size + self.layer_idx = layer_idx self.depth = int(d_model_size / self.num_heads) @@ -129,6 +131,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): batch_size = q.shape[0] @@ -139,15 +142,9 @@ def forward( q = self.split_into_heads(q, batch_size) k = self.split_into_heads(k, batch_size) v = self.split_into_heads(v, batch_size) - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - k = torch.cat((past_key, k), dim=-2) - v = torch.cat((past_value, v), dim=-2) - if use_cache is True: - present = torch.stack((k, v)) - else: - present = (None,) + if layer_past is not None: + k, v = layer_past.update(k, v, self.layer_idx, {"cache_position": cache_position}) output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) scaled_attention = output[0].permute([0, 2, 1, 3]) @@ -155,7 +152,7 @@ def forward( original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) output = self.dense(original_size_attention) - outputs = (output, present) + outputs = (output, layer_past) if output_attentions: outputs = outputs + (attn,) return outputs @@ -166,10 +163,10 @@ def point_wise_feed_forward_network(d_model_size, dff): class EncoderLayer(nn.Module): - def __init__(self, d_model_size, num_heads, dff, rate=0.1): + def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_idx=None): super().__init__() - self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads) + self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, layer_idx=layer_idx) self.ffn = point_wise_feed_forward_network(d_model_size, dff) self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6) @@ -179,7 +176,15 @@ def __init__(self, d_model_size, num_heads, dff, rate=0.1): self.dropout2 = nn.Dropout(rate) def forward( - self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False + self, + x, + mask, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, ): normed = self.layernorm1(x) attn_outputs = self.multi_head_attention( @@ -192,6 +197,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] attn_output = self.dropout1(attn_output) @@ -242,7 +248,10 @@ def __init__(self, config): self.dropout = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList( - [EncoderLayer(config.n_embd, 
config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)] + [ + EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, layer_idx=i) + for i in range(config.n_layer) + ] ) self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) @@ -276,6 +285,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]: r""" @@ -332,11 +342,17 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) @@ -387,24 +403,25 @@ def forward( hidden_states = self.dropout(hidden_states) - presents = () if use_cache else None + next_decoder_cache = None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, h in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = h( hidden_states, mask, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=attention_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present = outputs[:2] + hidden_states = outputs[0] if use_cache is True: - presents = presents + (present,) + next_decoder_cache = outputs[1] if output_attentions: all_attentions += (outputs[2],) @@ -413,12 +430,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -462,6 +483,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]: r""" @@ -520,6 +542,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + 
cache_position=cache_position, ) hidden_states = transformer_outputs[0] diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a1d747476df8..8ed35e826258 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -138,7 +139,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText class Data2VecTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -163,6 +164,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -176,8 +178,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -186,43 +189,44 @@ def forward( # such that the encoder's padding tokens are not attended to. 
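# --- Illustrative sketch, not part of the model code ---
# The branch just below assumes an `EncoderDecoderCache` that bundles a self-attention cache and a
# cross-attention cache and records, per layer, whether the cross-attention projections were already
# written (`is_updated`), so they are computed only on the first decoding step. Shapes are made up.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

k = torch.randn(1, 4, 5, 8)  # (batch, heads, encoder_len, head_dim)
v = torch.randn(1, 4, 5, 8)
cache.cross_attention_cache.update(k, v, layer_idx=0)
cache.is_updated[0] = True   # later steps read the stored projections instead of re-projecting

assert cache.is_updated.get(0) is True
assert cache.cross_attention_cache.key_cache[0].shape[-2] == 5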
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
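# --- Illustrative sketch, not part of the model code ---
# With a cache present, the "relative_key" branch below scores a single new query position
# (index `key_length - 1`) against every cached key. A standalone version of that computation,
# with made-up sizes (max_position_embeddings=16, head_dim=8, 2 heads, batch=1):
import torch

max_pos, head_dim = 16, 8
distance_embedding = torch.nn.Embedding(2 * max_pos - 1, head_dim)

key_length = 6                                   # 5 cached tokens + 1 new one
query_layer = torch.randn(1, 2, 1, head_dim)     # only the new position is queried
key_layer = torch.randn(1, 2, key_length, head_dim)

position_ids_l = torch.tensor(key_length - 1).view(-1, 1)   # cached-decoding path
position_ids_r = torch.arange(key_length).view(1, -1)
positional_embedding = distance_embedding(position_ids_l - position_ids_r + max_pos - 1)

relative_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
assert relative_scores.shape == (1, 2, 1, key_length)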
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -293,10 +297,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT class Data2VecTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = Data2VecTextSelfOutput(config) self.pruned_heads = set() @@ -326,8 +332,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -337,6 +344,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -376,17 +384,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText class Data2VecTextLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = Data2VecTextAttention(config) + self.attention = Data2VecTextAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute") + self.crossattention = Data2VecTextAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = Data2VecTextIntermediate(config) self.output = Data2VecTextOutput(config) @@ -397,28 +407,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( 
hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -426,24 +434,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -451,7 +454,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -463,10 +466,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText class Data2VecTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -481,6 +484,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -493,13 +497,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -509,8 +522,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -519,13 +533,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -534,12 +549,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -548,7 +567,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 0ebfa3d479ca..a75c00211edf 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, @@ -199,7 +200,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra class ElectraSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -224,6 +225,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -237,8 +239,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -247,43 +250,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
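# --- Illustrative sketch, not part of the model code ---
# `cache_position` names the absolute positions of the tokens written this step; the code above
# forwards it to Cache.update() as {"cache_position": ...} and drops it for cross-attention, whose
# keys/values do not grow. DynamicCache ignores the hint, while static caches use it to index their
# preallocated buffers. A made-up prefill plus two decode steps:
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
for chunk_len in (5, 1, 1):                       # prefill of 5 tokens, then two single tokens
    past_len = cache.get_seq_length()
    cache_position = torch.arange(past_len, past_len + chunk_len)
    k = torch.randn(1, 4, chunk_len, 8)           # (batch, heads, chunk_len, head_dim)
    v = torch.randn(1, 4, chunk_len, 8)
    cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": cache_position})

assert cache.get_seq_length() == 7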
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -354,10 +358,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA class ElectraAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ElectraSelfOutput(config) self.pruned_heads = set() @@ -387,8 +393,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -398,6 +405,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -437,17 +445,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra class ElectraLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ElectraAttention(config) + self.attention = ElectraAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ElectraAttention(config, position_embedding_type="absolute") + self.crossattention = ElectraAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = ElectraIntermediate(config) self.output = ElectraOutput(config) @@ -458,28 +466,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - 
past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -487,24 +493,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -512,7 +513,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -524,10 +525,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra class ElectraEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ElectraLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -542,6 +543,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -554,13 +556,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -570,8 +581,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -580,13 +592,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -595,12 +608,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -609,7 +626,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index c07627fa792f..54c0bfa97bc6 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -124,7 +125,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Ernie class ErnieSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -149,6 +150,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -162,8 +164,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, 
output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -172,43 +175,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
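The rewritten self-attention above reads and writes a Cache object instead of concatenating per-layer key/value tuples. A minimal, self-contained sketch of that pattern, assuming a recent transformers release that ships DynamicCache and EncoderDecoderCache (the shapes and layer index below are illustrative, not taken from the patch):

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx = 0
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# prefill: the whole prompt is projected once and written to the self-attention cache
bsz, num_heads, prompt_len, head_dim = 1, 2, 4, 8
key = torch.randn(bsz, num_heads, prompt_len, head_dim)
value = torch.randn(bsz, num_heads, prompt_len, head_dim)
key, value = cache.self_attention_cache.update(
    key, value, layer_idx, {"cache_position": torch.arange(prompt_len)}
)
print(cache.get_seq_length())  # 4

# decode: only the new token is projected; update() returns the concatenated K/V
new_key = torch.randn(bsz, num_heads, 1, head_dim)
new_value = torch.randn(bsz, num_heads, 1, head_dim)
key, value = cache.self_attention_cache.update(
    new_key, new_value, layer_idx, {"cache_position": torch.tensor([prompt_len])}
)
print(key.shape)  # torch.Size([1, 2, 5, 8])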
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -279,10 +283,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE class ErnieAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ErnieSelfOutput(config) self.pruned_heads = set() @@ -312,8 +318,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -323,6 +330,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -362,17 +370,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie class ErnieLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ErnieAttention(config) + self.attention = ErnieAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ErnieAttention(config, position_embedding_type="absolute") + self.crossattention = ErnieAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = ErnieIntermediate(config) self.output = ErnieOutput(config) @@ -383,28 +391,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - 
past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -412,24 +418,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -437,7 +438,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -449,10 +450,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Ernie class ErnieEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([ErnieLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ErnieLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -467,6 +468,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -479,13 +481,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -495,8 +506,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -505,13 +517,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -520,12 +533,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -534,7 +551,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 68d05e509f97..c8fb1fbbb16e 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -121,7 +122,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM class LayoutLMSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -146,6 +147,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -159,8 +161,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -169,43 +172,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
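The encoder-level deprecation shim added to these models converts a legacy tuple cache into an EncoderDecoderCache on the way in and, when a tuple was passed, converts it back on the way out. A hedged sketch of that round-trip, with a made-up layer count and tensor sizes:

import torch
from transformers.cache_utils import EncoderDecoderCache

# legacy format: one tuple per layer holding
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
bsz, num_heads, seq_len, head_dim = 1, 2, 3, 4
legacy = tuple(
    tuple(torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
    for _ in range(2)  # two decoder layers
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(type(cache.self_attention_cache).__name__, cache.get_seq_length())  # DynamicCache 3

# ... run the decoder with past_key_values=cache ...

legacy_again = cache.to_legacy_cache()
print(len(legacy_again), len(legacy_again[0]))  # 2 4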
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -276,10 +280,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM class LayoutLMAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = LayoutLMSelfOutput(config) self.pruned_heads = set() @@ -309,8 +315,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -320,6 +327,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -359,17 +367,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM class LayoutLMLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = LayoutLMAttention(config) + self.attention = LayoutLMAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute") + self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = LayoutLMIntermediate(config) self.output = LayoutLMOutput(config) @@ -380,28 +388,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, 
output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -409,24 +415,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -434,7 +435,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -446,10 +447,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM class LayoutLMEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([LayoutLMLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -464,6 +465,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -476,13 +478,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -492,8 +503,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -502,13 +514,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -517,12 +530,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -531,7 +548,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index e6e21ce897da..2b2d17f28d8a 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions @@ -764,9 +765,10 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[bool] = None, ): super().__init__() self.embed_dim = embed_dim @@ -780,6 +782,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -793,11 +796,12 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -807,40 +811,44 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -963,7 +971,7 @@ def forward( class LEDDecoderLayer(nn.Module): - def __init__(self, config: LEDConfig): + def __init__(self, config: LEDConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -972,6 +980,7 @@ def __init__(self, config: LEDConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -983,6 +992,7 @@ def __init__(self, config: LEDConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -997,9 +1007,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -1021,15 +1032,13 @@ def forward( residual = hidden_states # Self-Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 
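In the LED decoder layer the per-layer tuple slicing is gone: the same Cache object is handed to both self-attention and cross-attention, and the EncoderDecoderCache keeps the two apart. An illustrative sketch (not part of the patch) of how the two sub-caches evolve during decoding:

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx, bsz, num_heads, head_dim = 0, 1, 2, 8
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# cross-attention K/V come from the encoder output, are written once, and are
# then reused on every decoding step; the patch guards the reuse with `is_updated`
encoder_len = 6
if not cache.is_updated.get(layer_idx, False):
    k = torch.randn(bsz, num_heads, encoder_len, head_dim)
    v = torch.randn(bsz, num_heads, encoder_len, head_dim)
    cache.cross_attention_cache.update(k, v, layer_idx)
    cache.is_updated[layer_idx] = True

# self-attention K/V grow by one position per generated token
for step in range(3):
    k = torch.randn(bsz, num_heads, 1, head_dim)
    v = torch.randn(bsz, num_heads, 1, head_dim)
    cache.self_attention_cache.update(k, v, layer_idx, {"cache_position": torch.tensor([step])})

print(cache.cross_attention_cache.key_cache[layer_idx].shape[-2])  # 6
print(cache.self_attention_cache.key_cache[layer_idx].shape[-2])   # 3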
hidden_states = residual + hidden_states @@ -1041,23 +1050,19 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -1073,7 +1078,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1761,7 +1766,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.max_target_positions, config.d_model, ) - self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([LEDDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1783,6 +1788,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1876,12 +1882,27 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -1911,18 +1932,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1941,8 +1955,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1955,6 +1967,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1966,15 +1979,16 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1985,6 +1999,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -2049,6 +2066,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2134,6 +2152,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2219,6 +2238,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2338,6 +2358,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 9839e13f3666..d56a1d9ef85c 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -481,10 +481,10 @@ def layout_feed_forward_chunk(self, attention_output): class LiltEncoder(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Lilt - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = 
nn.ModuleList([LiltLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index f2ebe388de3e..e929b9141f9e 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -327,7 +328,7 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM class MarkupLMSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -352,6 +353,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -365,8 +367,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -375,43 +378,44 @@ def forward( # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
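The rewritten blocks also thread a cache_position tensor down to the cache update. A small sketch of the convention used by callers elsewhere in transformers to build it (shown only for illustration; the dummy shapes are not from the patch):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# prefill: positions 0 .. prompt_len - 1
prompt_len = 4
past_len = cache.get_seq_length()            # 0 for an empty cache
cache_position = torch.arange(past_len, past_len + prompt_len)
cache.update(
    torch.zeros(1, 1, prompt_len, 2), torch.zeros(1, 1, prompt_len, 2), 0,
    {"cache_position": cache_position},
)
print(cache_position.tolist())  # [0, 1, 2, 3]

# decode: a single new position that continues where the cache ends
past_len = cache.get_seq_length()            # 4 after the prefill above
cache_position = torch.arange(past_len, past_len + 1)
print(cache_position.tolist())  # [4]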
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -467,10 +471,12 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM class MarkupLMAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = MarkupLMSelfOutput(config) self.pruned_heads = set() @@ -500,8 +506,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -511,6 +518,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -519,17 +527,17 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM class MarkupLMLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = MarkupLMAttention(config) + self.attention = MarkupLMAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute") + self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = MarkupLMIntermediate(config) self.output = MarkupLMOutput(config) @@ -540,28 +548,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = 
self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -569,24 +575,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -594,7 +595,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -606,10 +607,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM class MarkupLMEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([MarkupLMLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -624,6 +625,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -636,13 +638,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -652,8 +663,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -662,13 +674,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -677,12 +690,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -691,7 +708,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 557cdb1fee75..eeafbbf5e5b8 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -177,7 +178,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert class MegatronBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -202,6 +203,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -215,8 +217,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: 
Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)
 
@@ -225,43 +228,44 @@ def forward(
         # such that the encoder's padding tokens are not attended to.
         is_cross_attention = encoder_hidden_states is not None
 
-        if is_cross_attention and past_key_value is not None:
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+            key_layer = curr_past_key_value.key_cache[self.layer_idx]
+            value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+
+            if past_key_value is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
-
         # Take the dot product between "query" and "key" to get the raw attention scores.
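# Illustrative sketch (an assumption for this note, not code from the patch): the cache
# read/write pattern that the rewritten decoder self-attention above relies on. Shapes,
# the layer index and sequence lengths are arbitrary stand-ins; only `DynamicCache` and
# `EncoderDecoderCache` from `transformers.cache_utils` are assumed, with the methods used here.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx, batch, heads, head_dim = 0, 1, 2, 4

# decoder self-attention: append this step's key/value states to the per-layer cache
new_k = torch.randn(batch, heads, 1, head_dim)
new_v = torch.randn(batch, heads, 1, head_dim)
key_layer, value_layer = cache.self_attention_cache.update(new_k, new_v, layer_idx)

# cross-attention: projected once from the encoder output, then re-used on every later
# step via the `is_updated` flag, mirroring what the hunks above do per layer
if not cache.is_updated.get(layer_idx, False):
    enc_k = torch.randn(batch, heads, 7, head_dim)  # encoder sequence length of 7
    enc_v = torch.randn(batch, heads, 7, head_dim)
    cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
    cache.is_updated[layer_idx] = True
key_layer_cross = cache.cross_attention_cache.key_cache[layer_idx]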
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 10e91c988a74..dff4553b6a34 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -68,7 +69,7 @@ class MptAttention(nn.Module): Using torch or triton attention implementation enables user to also use additive bias. """ - def __init__(self, config: MptConfig): + def __init__(self, config: MptConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size self.n_heads = config.n_heads @@ -82,13 +83,15 @@ def __init__(self, config: MptConfig): self.clip_qkv = config.attn_config.clip_qkv self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, position_bias: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ): batch_size, seq_length = hidden_states.shape[:2] @@ -102,12 +105,8 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - if len(past_key_value) != 0: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - past_key_value = (key_states, value_states) - else: - past_key_value = (key_states, value_states) + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -161,7 +160,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. 
class MptBlock(nn.Module): - def __init__(self, config: MptConfig): + def __init__(self, config: MptConfig, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size @@ -170,7 +169,7 @@ def __init__(self, config: MptConfig): self.norm_1.bias = None self.num_heads = config.n_heads - self.attn = MptAttention(config) + self.attn = MptAttention(config, layer_idx) self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) # backward compatibility with weights on the Hub @@ -186,9 +185,10 @@ def forward( hidden_states: torch.Tensor, position_bias: torch.Tensor, attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Cache] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ): # hidden_states: [batch_size, seq_length, hidden_size] # Layer norm at the beginning of the transformer layer. @@ -202,6 +202,7 @@ def forward( position_bias=position_bias, attention_mask=attention_mask, past_key_value=layer_past, + cache_position=cache_position, ) hidden_states = self.resid_attn_dropout(attn_outputs) + residual @@ -284,7 +285,7 @@ def __init__(self, config: MptConfig): self.wte = nn.Embedding(config.vocab_size, self.hidden_size) # Transformer blocks - self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)]) # Final Layer Norm self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon) @@ -309,13 +310,14 @@ def set_input_embeddings(self, new_embeddings: torch.Tensor): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" @@ -347,25 +349,32 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if past_key_values is None: - past_key_values = tuple([None] * len(self.blocks)) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.wte(input_ids) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + hidden_states = inputs_embeds - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation seq_length_with_past = seq_length past_key_values_length = 0 @@ -384,7 +393,7 @@ def forward( ) causal_mask = causal_mask.bool() - for block, layer_past in zip(self.blocks, past_key_values): + for block in self.blocks: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -394,23 +403,25 @@ def forward( hidden_states, alibi, causal_mask, - layer_past, + past_key_values, use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, use_cache=use_cache, output_attentions=output_attentions, position_bias=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] if use_cache is True: - presents = presents + (outputs[1],) + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -418,15 +429,21 @@ def forward( # Add last hidden state hidden_states = self.norm_f(hidden_states) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -467,6 +484,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -497,6 +515,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 29874999e052..276d7ee5e40a 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -96,9 +97,10 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[bool] = None, 
): super().__init__() self.embed_dim = embed_dim @@ -113,24 +115,23 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, attn_prompt: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -142,35 +143,38 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True if attn_prompt is not None: key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2) @@ -180,9 +184,10 @@ def forward( attention_mask = torch.cat([prompt_mask, attention_mask], dim=(-1)) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -317,7 +322,7 @@ def forward( class MvpDecoderLayer(nn.Module): - def __init__(self, config: MvpConfig): + def __init__(self, config: MvpConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -326,6 +331,7 @@ def __init__(self, config: MvpConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -337,6 +343,7 @@ def __init__(self, config: MvpConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -353,9 +360,10 @@ def forward( cross_attn_layer_head_mask: Optional[torch.Tensor] = None, self_attn_prompt: Optional[torch.Tensor] = None, cross_attn_prompt: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -382,45 +390,37 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, 
past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, attn_prompt=self_attn_prompt, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, attn_prompt=cross_attn_prompt, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -436,7 +436,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -744,7 +744,7 @@ def __init__( config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([MvpDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MvpDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.use_prompt = use_prompt @@ -785,6 +785,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -871,12 +872,27 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -902,18 +918,11 @@ def forward( self_attn_prompt = self.self_attn_prompt(prompt_ids) cross_attn_prompt = self.cross_attn_prompt(prompt_ids) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -933,8 +942,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -949,6 +956,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -962,14 +970,15 @@ def forward( ), self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -982,6 +991,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1055,6 +1067,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1138,6 +1151,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1223,6 +1237,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1312,6 +1327,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + 
self.final_logits_bias diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 6d7bd6c985d1..b3c83153acc0 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -515,6 +516,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[NllbMoeConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -531,6 +533,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -541,10 +544,11 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -565,42 +569,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `encoder_hidden_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = 
curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -708,7 +707,7 @@ def forward( class NllbMoeDecoderLayer(nn.Module): - def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model self.is_sparse = is_sparse @@ -730,6 +729,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) if not self.is_sparse: @@ -747,10 +747,11 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = True, ) -> torch.Tensor: """ Args: @@ -1120,7 +1121,7 @@ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = self.layers = nn.ModuleList() for i in range(config.decoder_layers): is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False - self.layers.append(NllbMoeDecoderLayer(config, is_sparse)) + self.layers.append(NllbMoeDecoderLayer(config, is_sparse, layer_idx=i)) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1143,6 +1144,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = True, ): r""" Args: @@ -1230,12 +1232,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids 
or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = self._update_causal_mask( attention_mask, input_shape, @@ -1257,19 +1275,12 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_probs = () if output_router_logits else None all_cross_attentions = () if output_attentions else None - present_key_value_states = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1293,8 +1304,6 @@ def forward( layer_head_mask = head_mask[idx] if head_mask is not None else None cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: if use_cache: @@ -1313,6 +1322,7 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1322,10 +1332,11 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1334,7 +1345,7 @@ def forward( continue if use_cache: - present_key_value_states += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_self_attns += (layer_outputs[2],) @@ -1349,12 +1360,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, 
all_self_attns, all_cross_attentions, @@ -1364,7 +1379,7 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1503,6 +1518,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = True, ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1582,6 +1598,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1652,6 +1669,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1723,6 +1741,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 0f7cfcd224dc..121629e9de0f 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,6 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -488,11 +489,7 @@ def _forward(self, position_ids): class ProphetNetAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( - self, - config: ProphetNetConfig, - num_attn_heads: int, - ): + def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size @@ -500,6 +497,7 @@ def __init__( self.dropout = config.dropout self.num_attn_heads = num_attn_heads self.head_dim = hidden_size // num_attn_heads + self.layer_idx = layer_idx assert self.head_dim * num_attn_heads == hidden_size, ( "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and" @@ -512,17 +510,15 @@ def __init__( self.out_proj = nn.Linear(hidden_size, hidden_size) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - past_key_value: Optional[Tuple[Tensor]] = None, - output_attentions: bool = False, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: batch_size, tgt_len, hidden_size = hidden_states.size() @@ -537,32 +533,46 @@ def forward( # previous time 
steps are cached - no need to recompute key and value if they are static
         query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
 
-        if is_cross_attention and past_key_value is not None:
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
+        current_states = key_value_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            # self_attention
-            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)
-
-        if is_cross_attention:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        # project states into the correct shape
-        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+            key_states = self.key_proj(current_states)
+            value_states = self.value_proj(current_states)
+            key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
+
+            if past_key_value is not None:
+                # save all key/value_states to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_states, value_states = curr_past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True
+
+        # project query states into (batch_size, num_attn_heads, tgt_len, head_dim);
+        # key/value states already have that shape after the cache read/update above
+        query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2)
 
         src_len = key_states.size(2)
         attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
         expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
@@ -638,7 +648,7 @@ def forward(self, hidden_states):
 
 
 class ProphetNetNgramSelfAttention(nn.Module):
-    def __init__(self, config:
ProphetNetConfig):
+    def __init__(self, config: ProphetNetConfig, layer_idx=None):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -649,6 +659,7 @@ def __init__(self, config: ProphetNetConfig):
         self.attention_dropout = config.attention_dropout
         self.head_dim = config.hidden_size // self.num_attn_heads
         self.ngram = config.ngram
+        self.layer_idx = layer_idx
 
         assert self.head_dim * self.num_attn_heads == config.hidden_size, (
             "config.hidden_size must be divisible by num_attn_heads"
@@ -683,6 +694,7 @@ def forward(
         main_relative_position_buckets=None,
         predict_relative_position_buckets=None,
         position_ids=None,
+        cache_position=None,
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
         assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
@@ -721,13 +733,9 @@ def forward(
 
         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
         if past_key_value is not None:
-            prev_main_key_states = past_key_value[0]
-            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
-            prev_main_value_states = past_key_value[1]
-            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
-
-            # Update cache
-            past_key_value = (main_key_states, main_value_states)
+            # decoder ngram self-attention only ever writes to the self-attention cache
+            if isinstance(past_key_value, EncoderDecoderCache):
+                curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+            main_key_states, main_value_states = curr_past_key_value.update(
+                main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position}
+            )
 
         # get seq_length of main stream only
         sequence_length = ngram_sequence_length // (1 + self.ngram)
@@ -1004,15 +1012,15 @@ class ProphetNetDecoderLayer(nn.Module):
     Decoder block for Prophetnet
     """
 
-    def __init__(self, config: ProphetNetConfig):
+    def __init__(self, config: ProphetNetConfig, layer_idx=None):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetNgramSelfAttention(config)
+        self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
         if config.add_cross_attention:
-            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
+            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, layer_idx=layer_idx)
             self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 3rd residual block
@@ -1032,15 +1040,14 @@ def forward(
         predict_relative_position_buckets=None,
         position_ids=None,
         past_key_value=None,
-        use_cache: bool = True,
-        output_attentions: bool = False,
+        use_cache: Optional[bool] = True,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ):
         # 1st residual block
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
-        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
+        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, past_key_value = self.self_attn(
             hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
             extended_predict_attention_mask=extended_predict_attention_mask,
@@ -1051,23 +1058,19 @@ def forward(
         hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)
 
-        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
-        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
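# Illustrative sketch (an assumption for this note, not code from the patch): why a decoder
# layer can now hand the whole cache to both attention modules instead of slicing a legacy
# 4-tuple into positions 1,2 (self-attn) and 3,4 (cross-attn). Tensors below are arbitrary
# stand-ins; only `DynamicCache`/`EncoderDecoderCache` from `transformers.cache_utils` are assumed.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
# the self-attention sub-cache grows by one position per generated token ...
past_key_values.self_attention_cache.update(torch.zeros(1, 2, 1, 4), torch.zeros(1, 2, 1, 4), 0)
# ... while the cross-attention sub-cache is written once per layer from the encoder output
past_key_values.cross_attention_cache.update(torch.zeros(1, 2, 7, 4), torch.zeros(1, 2, 7, 4), 0)

print(past_key_values.get_seq_length())                        # expected: 1 (decoder length)
print(past_key_values.cross_attention_cache.get_seq_length())  # expected: 7 (encoder length)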
cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block - attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + attention_output, cross_attn_weights, past_key_value = self.cross_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # 3rd residual block feed_forward_output = self.feed_forward(hidden_states) hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) @@ -1078,7 +1081,7 @@ def forward( outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1242,7 +1245,9 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedd self.position_embeddings = ProphetNetPositionalEmbeddings(config) self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) - self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)]) + self.layers = nn.ModuleList( + [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)] + ) self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.gradient_checkpointing = False @@ -1270,6 +1275,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, ProphetNetDecoderModelOutput]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1313,7 +1319,26 @@ def forward( past_key_values=past_key_values, ) - if past_key_values is not None: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + + if past_key_values_length != 0: main_relative_position_buckets, predict_relative_position_buckets = None, None else: ( @@ -1328,7 +1353,7 @@ def forward( ngram_embeddings = self.ngram_embeddings.weight # prepare attention mask - if past_key_values is not None: + if past_key_values_length != 0: assert hidden_states.size(1) == 1, ( "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" ) @@ -1369,15 +1394,7 @@ def forward( all_main_stream_attns = () if output_attentions else None all_ngram_stream_attns = () if output_attentions else None all_cross_attns = () if output_attentions and self.config.add_cross_attention else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - present_key_values = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1393,8 +1410,6 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1411,6 +1426,7 @@ def forward( None, use_cache, output_attentions, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1426,15 +1442,16 @@ def forward( main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - present_key_values += (layer_outputs[4 if output_attentions else 1],) + next_decoder_cache = layer_outputs[4 if output_attentions else 1] if output_attentions: all_main_stream_attns += (layer_outputs[1],) @@ -1448,6 +1465,10 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + # split last_hidden_state for return last_hidden_state = hidden_states[:, :sequence_length] last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None @@ -1458,7 +1479,7 @@ def forward( for v in [ last_hidden_state, last_hidden_state_ngram, - present_key_values, + next_cache, all_main_stream_hidden_states, all_ngram_stream_hidden_states, all_main_stream_attns, @@ -1470,7 +1491,7 @@ def forward( return ProphetNetDecoderModelOutput( last_hidden_state=last_hidden_state, last_hidden_state_ngram=last_hidden_state_ngram, - past_key_values=present_key_values, + past_key_values=next_cache, hidden_states=all_main_stream_hidden_states, hidden_states_ngram=all_ngram_stream_hidden_states, attentions=all_main_stream_attns, @@ -1618,6 +1639,7 @@ def forward( output_attentions: Optional[bool] = 
None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, ProphetNetSeq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1689,6 +1711,7 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1759,6 +1782,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, ProphetNetSeq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1824,6 +1848,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) batch_size, sequence_length = ( decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2] diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index e2456afafa9d..2f9c3dbd9e20 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput @@ -47,7 +48,7 @@ class RetrievAugLMMarginOutput(ModelOutput): doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`): Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and `question_encoder_last_hidden_state`. - past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). @@ -112,7 +113,7 @@ class RetrievAugLMMarginOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None doc_scores: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[Cache] = None retrieved_doc_embeds: Optional[torch.FloatTensor] = None retrieved_doc_ids: Optional[torch.LongTensor] = None context_input_ids: Optional[torch.LongTensor] = None @@ -138,7 +139,7 @@ class RetrievAugLMOutput(ModelOutput): doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`): Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and `question_encoder_last_hidden_state`. - past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). 
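# Illustrative sketch (an assumption for this note, not code from the patch): the legacy
# tuple <-> `Cache` conversion that the `past_key_values` type changes and deprecation
# warnings in this patch refer to. Shapes and the two-layer layout are arbitrary stand-ins.
import torch
from transformers.cache_utils import EncoderDecoderCache

legacy = tuple(
    (
        torch.zeros(1, 2, 3, 4),  # self-attention keys
        torch.zeros(1, 2, 3, 4),  # self-attention values
        torch.zeros(1, 2, 7, 4),  # cross-attention keys
        torch.zeros(1, 2, 7, 4),  # cross-attention values
    )
    for _ in range(2)  # one 4-tuple per decoder layer
)
cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # expected: 3, the cached decoder (self-attention) length
roundtrip = cache.to_legacy_cache()
print(len(roundtrip), len(roundtrip[0]))  # expected: 2 layers, 4 tensors per layer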
@@ -202,7 +203,7 @@ class RetrievAugLMOutput(ModelOutput): logits: Optional[torch.FloatTensor] = None doc_scores: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[Cache] = None retrieved_doc_embeds: Optional[torch.FloatTensor] = None retrieved_doc_ids: Optional[torch.LongTensor] = None context_input_ids: Optional[torch.LongTensor] = None @@ -435,7 +436,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, doc_scores: Optional[torch.FloatTensor] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, @@ -709,7 +710,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, @@ -1219,7 +1220,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 6a8ba4bf1125..3f0ec7f03e0d 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -198,7 +199,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class RemBertSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -217,6 +218,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -230,8 +232,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple: mixed_query_layer = self.query(hidden_states) @@ -240,36 +243,38 @@ def forward( # such that the 
encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -318,9 +323,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RemBertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = RemBertSelfAttention(config) + self.self = RemBertSelfAttention(config, layer_idx=layer_idx) self.output = RemBertSelfOutput(config) self.pruned_heads = set() @@ -351,8 +356,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -362,6 +368,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -400,17 +407,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RemBertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RemBertAttention(config) + self.attention = RemBertAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RemBertAttention(config) + self.crossattention = RemBertAttention(config, layer_idx=layer_idx) self.intermediate = RemBertIntermediate(config) self.output = RemBertOutput(config) @@ -422,28 +429,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -451,24 +456,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - 
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -476,7 +476,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -493,7 +493,7 @@ def __init__(self, config): self.config = config self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) - self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RemBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -508,6 +508,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -515,18 +516,28 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + hidden_states = self.embedding_hidden_mapping_in(hidden_states) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -536,7 +547,7 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) else: @@ -546,13 +557,13 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -561,12 +572,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -575,7 +590,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -710,6 +725,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -783,6 +799,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index a88cf0c9c63b..1d0b591be9eb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -137,7 +138,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from 
transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -162,6 +163,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -175,8 +177,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -185,43 +188,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if 
is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -272,8 +276,8 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta class RobertaSdpaSelfAttention(RobertaSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") @@ -285,13 +289,14 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. 
This warning can be removed using the argument " @@ -305,6 +310,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() @@ -318,25 +324,35 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] + else: + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
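A condensed sketch of the key/value routing the hunks above implement, with a toy stand-in for the model's key/value projections; the helper name, shapes, and values here are illustrative only, and only the `transformers.cache_utils` calls mirror the patched code.

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

def toy_kv_with_cache(states, cache, layer_idx, is_cross_attention):
    # Pick the sub-cache this attention block reads from and writes to.
    sub_cache = cache.cross_attention_cache if is_cross_attention else cache.self_attention_cache

    # Cross-attention keys/values depend only on the encoder output, so after the
    # first decoding step they can be read straight from the cache.
    if is_cross_attention and cache.is_updated.get(layer_idx):
        return sub_cache.key_cache[layer_idx], sub_cache.value_cache[layer_idx]

    # Stand-in for the model's key/value projections.
    key_states = states.unsqueeze(1)    # (batch, 1 "head", seq, dim)
    value_states = states.unsqueeze(1)

    # Appends along the sequence dimension for self-attention; stores once for cross-attention.
    key_states, value_states = sub_cache.update(key_states, value_states, layer_idx)
    if is_cross_attention:
        cache.is_updated[layer_idx] = True
    return key_states, value_states

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
encoder_out = torch.randn(1, 6, 8)
step_1 = torch.randn(1, 1, 8)
step_2 = torch.randn(1, 1, 8)

toy_kv_with_cache(step_1, cache, layer_idx=0, is_cross_attention=False)
k_cross, _ = toy_kv_with_cache(encoder_out, cache, layer_idx=0, is_cross_attention=True)
k_self, _ = toy_kv_with_cache(step_2, cache, layer_idx=0, is_cross_attention=False)
assert k_self.shape[-2] == 2   # self-attention cache grew with the new token
assert k_cross.shape[-2] == 6  # cross-attention cache holds the full encoder sequence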
@@ -395,10 +411,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA class RobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -428,8 +446,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -439,6 +458,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -478,17 +498,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta class RobertaLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaAttention(config) + self.attention = RobertaAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = RobertaIntermediate(config) self.output = RobertaOutput(config) @@ -499,28 +519,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention 
weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -528,24 +546,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -553,7 +566,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -565,10 +578,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta class RobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -583,6 +596,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -595,13 +609,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -611,8 +634,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -621,13 +645,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -636,12 +661,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -650,7 +679,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -766,6 +795,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -877,6 +907,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 0f3aeaf4d04a..8cc11246f1c8 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -136,7 +137,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RobertaPreLayerNorm class RobertaPreLayerNormSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, 
config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -161,6 +162,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -174,8 +176,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -184,43 +187,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # 
if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -366,17 +370,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm class RobertaPreLayerNormLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaPreLayerNormAttention(config) + self.attention = RobertaPreLayerNormAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaPreLayerNormAttention(config, position_embedding_type="absolute") + self.crossattention = RobertaPreLayerNormAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = RobertaPreLayerNormIntermediate(config) self.output = RobertaPreLayerNormOutput(config) @@ -387,28 +393,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -416,24 +420,19 @@ def forward( " by 
setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -441,7 +440,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -453,10 +452,12 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RobertaPreLayerNorm class RobertaPreLayerNormEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RobertaPreLayerNormLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [RobertaPreLayerNormLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False def forward( @@ -471,6 +472,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -483,13 +485,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -499,8 +510,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -509,13 +521,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -524,12 +537,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -538,7 +555,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 42e9fb94e4e5..73e0f1ba54f5 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -251,7 +252,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RoCBert class RoCBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -276,6 +277,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -289,8 +291,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -299,43 +302,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -406,10 +410,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT class RoCBertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = RoCBertSelfOutput(config) self.pruned_heads = set() @@ -439,8 +445,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -450,6 +457,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -489,17 +497,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert class RoCBertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RoCBertAttention(config) + self.attention = RoCBertAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RoCBertAttention(config, position_embedding_type="absolute") + self.crossattention = RoCBertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = RoCBertIntermediate(config) self.output = RoCBertOutput(config) @@ -510,28 +518,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - 
past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -539,24 +545,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -564,7 +565,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -576,10 +577,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RoCBert class RoCBertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RoCBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RoCBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -594,6 +595,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -606,13 +608,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -622,8 +633,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -632,13 +644,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -647,12 +660,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -661,7 +678,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 1fdccf728463..2a8b77c6dbd5 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...cache_utils import Cache, EncoderDecoderCache, DynamicCache from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin from ...modeling_outputs import ( @@ -187,7 +188,7 @@ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None): class RoFormerSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -207,6 +208,7 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.rotary_value = config.rotary_value + self.layer_idx = layer_idx def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -223,6 +225,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) @@ -231,39 +234,35 @@ def forward( # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - if sinusoidal_pos is not None: - if self.rotary_value: - query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer, value_layer - ) - else: - query_layer, key_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer - ) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + if past_key_value is not None: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -340,9 +339,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RoFormerAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = RoFormerSelfAttention(config) + self.self = RoFormerSelfAttention(config, layer_idx=layer_idx) self.output = RoFormerSelfOutput(config) self.pruned_heads = set() @@ -376,6 +375,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position, ): self_outputs = self.self( hidden_states, @@ -386,6 +386,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -424,17 +425,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RoFormerLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RoFormerAttention(config) + self.attention = RoFormerAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RoFormerAttention(config) + self.crossattention = RoFormerAttention(config, layer_idx) self.intermediate = RoFormerIntermediate(config) self.output = RoFormerOutput(config) @@ -448,27 +449,25 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position, ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, sinusoidal_pos, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -476,8 +475,6 @@ def forward( "layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -485,16 +482,13 @@ def forward( head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - 
cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -502,7 +496,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -519,7 +513,7 @@ def __init__(self, config): self.embed_positions = RoFormerSinusoidalPositionalEmbedding( config.max_position_embeddings, config.hidden_size // config.num_attention_heads ) - self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RoFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -534,6 +528,7 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): if self.gradient_checkpointing and self.training: if use_cache: @@ -541,16 +536,27 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :] - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -569,6 +575,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -580,11 +587,12 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -593,12 +601,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -607,7 +619,7 @@ def forward( ) return 
BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -849,6 +861,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -921,6 +934,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] @@ -1076,6 +1090,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]: r""" @@ -1115,6 +1130,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 42f3e4b577c3..e4fdf2e17efe 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1036,16 +1037,14 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1057,45 +1056,42 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `encoder_hidden_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = 
past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -1237,7 +1233,7 @@ def forward( class SeamlessM4TDecoderLayer(nn.Module): - def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None, layer_idx=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim decoder_attention_heads = ( @@ -1250,6 +1246,7 @@ def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_atte 
num_heads=decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1257,7 +1254,11 @@ def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_atte self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = SeamlessM4TAttention( - self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1272,9 +1273,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1298,41 +1300,33 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights, past_key_value = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -1343,7 +1337,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, present_key_value) + outputs = (hidden_states, past_key_value) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) @@ -1763,12 +1757,13 @@ def __init__( ) layers = [] - for _ in range(config.decoder_layers): + for i in range(config.decoder_layers): layers.append( SeamlessM4TDecoderLayer( config, decoder_attention_heads=config.decoder_attention_heads, decoder_ffn_dim=config.decoder_ffn_dim, + layer_idx=i, ) ) self.layers = 
nn.ModuleList(layers) @@ -1797,6 +1792,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1818,12 +1814,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -1842,18 +1854,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1864,8 +1869,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1876,6 +1879,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1883,14 +1887,15 @@ def forward( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_self_attns += (layer_outputs[2],) @@ -1905,6 +1910,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1957,6 +1965,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1994,6 +2003,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2082,6 +2092,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2113,6 +2124,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index c62b97fb89ce..8587780aa683 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -939,42 +940,55 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, 
bias=bias) - def _shape(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states)) - value_states = self._shape(self.v_proj(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q_proj(hidden_states) * self.scaling) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling attention_scores = 
torch.matmul(query_states, key_states.transpose(-1, -2)) if self.is_decoder: @@ -1095,7 +1109,9 @@ def forward( # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer with SeamlessM4T->SeamlessM4Tv2 class SeamlessM4Tv2DecoderLayer(nn.Module): - def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): + def __init__( + self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None, layer_idx=None + ): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim decoder_attention_heads = ( @@ -1108,6 +1124,7 @@ def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_at num_heads=decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1115,7 +1132,11 @@ def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_at self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = SeamlessM4Tv2Attention( - self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1130,9 +1151,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1156,41 +1178,33 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights, past_key_value = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = 
residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -1201,7 +1215,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, present_key_value) + outputs = (hidden_states, past_key_value) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) @@ -1832,12 +1846,13 @@ def __init__( ) layers = [] - for _ in range(config.decoder_layers): + for i in range(config.decoder_layers): layers.append( SeamlessM4Tv2DecoderLayer( config, decoder_attention_heads=config.decoder_attention_heads, decoder_ffn_dim=config.decoder_ffn_dim, + layer_idx=i, ) ) self.layers = nn.ModuleList(layers) @@ -1866,6 +1881,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1887,12 +1903,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -1911,18 +1943,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1933,8 +1958,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1945,6 +1968,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1952,14 +1976,15 @@ def forward( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_self_attns += (layer_outputs[2],) @@ -1974,6 +1999,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -2330,6 +2358,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" @@ -2359,6 +2388,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index aa4ea8107110..9c7530d70d5b 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -399,7 +400,7 @@ def forward( # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT # TODO: change copy when applying cache class class Speech2TextDecoderLayer(nn.Module): - def __init__(self, config: Speech2TextConfig): + def __init__(self, config: Speech2TextConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -410,6 +411,7 @@ def __init__(self, config: Speech2TextConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -773,7 +775,9 @@ def __init__(self, config: Speech2TextConfig): self.padding_idx, ) - self.layers = nn.ModuleList([Speech2TextDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + 
[Speech2TextDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
+        )
 
         self.layer_norm = nn.LayerNorm(config.d_model)
 
@@ -801,6 +805,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         r"""
         Args:
@@ -885,12 +890,27 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+                )
+                use_cache = False
+
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
         attention_mask = self._update_causal_mask(
             attention_mask,
             input_shape,
@@ -910,18 +930,11 @@ def forward(
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -939,8 +952,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -953,6 +964,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -964,14 +976,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -985,6 +998,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1116,6 +1132,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): @@ -1214,6 +1231,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1278,6 +1296,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): @@ -1361,6 +1380,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 6498947325d5..f69eb7ad5f4c 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer 
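The Speech2Text and Splinter changes rely on the same deprecation shim as the earlier models: a legacy tuple of past key/values is wrapped once with `EncoderDecoderCache.from_legacy_cache`, the past length is then read through `get_seq_length()` instead of `past_key_values[0][0].shape[2]`, and `to_legacy_cache()` converts back when the caller still expects tuples. A minimal sketch of that round trip (not part of the patch; the tensor shapes are made up and the exact `cache_utils` behaviour may differ between transformers versions):

import torch
from transformers.cache_utils import EncoderDecoderCache

# Illustrative legacy layout: one 4-tuple per layer
# (self_key, self_value, cross_key, cross_value), shaped [batch, heads, seq, head_dim].
legacy = tuple(
    (
        torch.zeros(1, 4, 3, 8),  # decoder self-attention key, past length 3
        torch.zeros(1, 4, 3, 8),  # decoder self-attention value
        torch.zeros(1, 4, 7, 8),  # cross-attention key, encoder length 7
        torch.zeros(1, 4, 7, 8),  # cross-attention value
    )
    for _ in range(2)  # two decoder layers
)

# One-time wrap, as done behind the deprecation warning in the new decoder code.
cache = EncoderDecoderCache.from_legacy_cache(legacy)

# Replaces the old `past_key_values[0][0].shape[2]` length computation.
assert cache.get_seq_length() == 3

# Converted back to tuples at the end of forward when `return_legacy_cache` is set.
assert cache.to_legacy_cache()[0][0].shape == (1, 4, 3, 8)

Because the conversion happens once at the decoder boundary, every per-layer module in these files can assume it receives a `Cache` object rather than a tuple.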
@@ -93,7 +94,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter class SplinterSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -118,6 +119,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -131,8 +133,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -141,43 +144,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already 
updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -248,10 +252,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER class SplinterAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = SplinterSelfOutput(config) self.pruned_heads = set() @@ -281,8 +287,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -292,6 +299,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -331,17 +339,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter class SplinterLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = SplinterAttention(config) + self.attention = SplinterAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = 
SplinterAttention(config, position_embedding_type="absolute") + self.crossattention = SplinterAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = SplinterIntermediate(config) self.output = SplinterOutput(config) @@ -352,28 +360,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -381,24 +387,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -406,7 +407,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -418,10 +419,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter class SplinterEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([SplinterLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -436,6 +437,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], 
BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -448,13 +450,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -464,8 +475,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -474,13 +486,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -489,12 +502,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -503,7 +520,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 7b1f335f7f2d..651d442fa679 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -23,6 +23,7 @@ from transformers import PreTrainedModel from transformers.models.superglue.configuration_superglue import SuperGlueConfig +from ...cache_utils import Cache, EncoderDecoderCache from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging from ..auto import AutoModelForKeypointDetection @@ -233,7 +234,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue class SuperGlueSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not 
hasattr(config, "embedding_size"): raise ValueError( @@ -258,6 +259,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -271,8 +273,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -281,43 +284,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -383,10 +387,12 @@ def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE class SuperGlueAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = SuperGlueSelfOutput(config) self.pruned_heads = set() @@ -416,8 +422,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -427,6 +434,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 35295ebf8890..d409f6742b74 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -427,8 +428,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: 
self_outputs = self.self( hidden_states, @@ -438,6 +440,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -498,28 +501,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -527,24 +528,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -552,7 +548,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index e9f3f63c965a..7df17973d723 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions @@ -106,9 +107,10 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool 
= False, - bias: bool = True, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -123,23 +125,22 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -148,41 +149,46 @@ def forward( is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -254,7 +260,7 @@ def forward( class XGLMDecoderLayer(nn.Module): - def __init__(self, config: XGLMConfig): + def __init__(self, config: XGLMConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -263,6 +269,7 @@ def __init__(self, config: XGLMConfig): num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -274,6 +281,7 @@ def __init__(self, config: XGLMConfig): num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -418,7 +426,7 @@ def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = No config.d_model, config.pad_token_id, ) - self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)]) + self.layers = nn.ModuleList([XGLMDecoderLayer(config, layer_idx=i) for i in range(config.num_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -447,6 +455,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): @@ -485,7 +494,31 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..." 
+ ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) if position_ids is None: position_ids = torch.arange( @@ -496,13 +529,6 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -515,18 +541,11 @@ def forward( ) hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -545,8 +564,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -559,6 +576,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -570,14 +588,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -592,6 +611,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -654,6 +676,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> 
Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -700,6 +723,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 6a702985dbbb..2a508ab915c7 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -138,7 +139,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta class XLMRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -163,6 +164,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -176,8 +178,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -186,43 +189,44 @@ def forward( # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -273,8 +277,8 @@ def forward( # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") @@ -286,13 +290,14 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. 
This warning can be removed using the argument " @@ -306,6 +311,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() @@ -319,25 +325,35 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] + else: + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
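For context, the hunks above swap the legacy tuple-of-tuples cache for the `Cache` objects imported from `...cache_utils`. A minimal, self-contained sketch of how those objects are driven, assuming a transformers release that ships `DynamicCache` and `EncoderDecoderCache` (tensor shapes below are illustrative only, not taken from the patch):

    import torch
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    # One joint cache object: decoder self-attention and encoder-decoder cross-attention
    # key/value states are kept in two separate DynamicCache instances.
    cache = EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())

    # What each refactored layer does per decoding step: append this layer's new key/value
    # states, keyed by layer_idx, and read back the full concatenated tensors.
    bsz, num_heads, step_len, head_dim = 1, 12, 1, 64
    new_keys = torch.randn(bsz, num_heads, step_len, head_dim)
    new_values = torch.randn(bsz, num_heads, step_len, head_dim)
    keys, values = cache.self_attention_cache.update(new_keys, new_values, layer_idx=0)

    # Legacy tuple caches can still be converted at the model boundary, as the deprecation
    # warnings added in this patch suggest.
    legacy = cache.to_legacy_cache()
    restored = EncoderDecoderCache.from_legacy_cache(legacy)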
@@ -396,10 +412,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA class XLMRobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = XLMRobertaSelfOutput(config) self.pruned_heads = set() @@ -429,8 +447,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, @@ -440,6 +459,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -479,17 +499,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta class XLMRobertaLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaAttention(config) + self.attention = XLMRobertaAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute") + self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = XLMRobertaIntermediate(config) self.output = XLMRobertaOutput(config) @@ -500,28 +520,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = 
self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -529,24 +547,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -554,7 +567,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -566,10 +579,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->XLMRoberta class XLMRobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([XLMRobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -584,6 +597,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -596,13 +610,22 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -612,8 +635,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -622,13 +646,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -637,12 +662,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -651,7 +680,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -757,6 +786,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -868,6 +898,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 6040e4002158..ec5b98013bc8 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( @@ -135,7 +136,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->XLMRobertaXL class XLMRobertaXLSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def 
__init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -160,6 +161,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -173,8 +175,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -183,43 +186,44 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if 
self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -270,8 +274,8 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") @@ -283,13 +287,14 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. 
This warning can be removed using the argument " @@ -303,6 +308,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() @@ -316,25 +322,35 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] + else: + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
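The cross-attention path in the hunk above projects the encoder states once, then re-reads them from the cache on later steps, gated by the `is_updated` flag. A rough standalone sketch of that gating, with placeholder shapes and a hypothetical `layer_idx` (again assuming `DynamicCache`/`EncoderDecoderCache` are available in the installed transformers version):

    import torch
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
    layer_idx = 0
    encoder_keys = torch.randn(1, 12, 7, 64)    # stand-ins for the projected encoder key states
    encoder_values = torch.randn(1, 12, 7, 64)  # stand-ins for the projected encoder value states

    for step in range(3):
        if not cache.is_updated.get(layer_idx, False):
            # first decoding step: store the encoder projections once
            keys, values = cache.cross_attention_cache.update(encoder_keys, encoder_values, layer_idx)
            cache.is_updated[layer_idx] = True
        else:
            # later steps: reuse the cached tensors instead of re-projecting the encoder states
            keys = cache.cross_attention_cache.key_cache[layer_idx]
            values = cache.cross_attention_cache.value_cache[layer_idx]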
@@ -749,6 +765,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -860,6 +877,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 646c952fdec0..d9746623afaf 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -135,7 +136,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod class XmodSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -160,6 +161,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -173,8 +175,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -183,43 +186,44 @@ def forward( # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
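        # shapes: query_layer is (batch, num_heads, query_len, head_dim) and key_layer is
        # (batch, num_heads, key_len, head_dim), so the scores below are (batch, num_heads, query_len, key_len),
        # where key_len includes any previously cached positions.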
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) From 5c5825bc1de0e176047c7d183869337068c815b1 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 5 Jun 2025 12:59:33 +0200 Subject: [PATCH 02/58] push other models --- src/transformers/models/bark/modeling_bark.py | 98 ++++++----- .../models/big_bird/modeling_big_bird.py | 126 ++++++++------ .../models/blip/modeling_blip_text.py | 112 +++++++----- .../models/imagegpt/modeling_imagegpt.py | 124 +++++++++----- .../models/kosmos2/modeling_kosmos2.py | 158 +++++++++-------- src/transformers/models/led/modeling_led.py | 3 - .../models/speecht5/modeling_speecht5.py | 161 +++++++++++------- .../models/trocr/modeling_trocr.py | 153 +++++++++-------- 8 files changed, 547 insertions(+), 388 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 7a564393f9e3..7cb09c0d95ef 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import functional as F +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, @@ -65,7 +66,7 @@ class BarkSelfAttention(nn.Module): # adapted from GPTNeoSelfAttention and Bark code # BarkSelfAttention can have two attention type, i.e full attention or causal attention - def __init__(self, config, is_causal=False): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() # regularization @@ -89,6 +90,7 @@ def __init__(self, config, is_causal=False): self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias) self.is_causal = is_causal + self.layer_idx = layer_idx if is_causal: block_size = config.block_size bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size) @@ -154,6 +156,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): # calculate query, key, values for all heads in batch and move head forward to be the batch dim query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) @@ -163,15 +166,7 @@ def forward( value = self._split_heads(value, self.num_heads, self.head_dim) if past_key_values is not None: - past_key = past_key_values[0] - past_value = past_key_values[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - present = (key, value) - else: - present = None + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) @@ -179,7 +174,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) + outputs = (attn_output, past_key_values) if output_attentions: outputs += (attn_weights,) @@ -228,6 +223,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): batch_size, query_len, _ = hidden_states.size() @@ -239,18 
+235,7 @@ def forward(
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if past_key_values is not None:
-            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
-            past_key = past_key_values[0].transpose(1, 2)
-            past_value = past_key_values[1].transpose(1, 2)
-            # and merge on seq_length
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
-
-        if use_cache is True:
-            # (batch, head, seq_length, head_features)
-            present = (key.transpose(1, 2), value.transpose(1, 2))
-        else:
-            present = None
+            key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
 
         attn_output = _flash_attention_forward(
             query,
@@ -267,7 +252,7 @@ def forward(
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
-        outputs = (attn_output, present)
+        outputs = (attn_output, past_key_values)
         if output_attentions:
             attn_weights = None
             outputs += (attn_weights,)
@@ -310,7 +295,7 @@ def forward(self, hidden_states):
 
 
 class BarkBlock(nn.Module):
-    def __init__(self, config, is_causal=False):
+    def __init__(self, config, is_causal=False, layer_idx=None):
         super().__init__()
 
         if is_causal:
@@ -323,7 +308,9 @@ def __init__(self, config, is_causal=False):
             self.layernorm_1 = nn.LayerNorm(config.hidden_size)
             self.layernorm_2 = nn.LayerNorm(config.hidden_size)
 
-        self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)
+        self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](
+            config, is_causal=is_causal, layer_idx=layer_idx
+        )
 
         self.mlp = BarkMLP(config)
 
@@ -335,6 +322,7 @@ def forward(
         head_mask=None,
         use_cache=False,
         output_attentions=False,
+        cache_position=None,
     ):
         intermediary_hidden_states = self.layernorm_1(hidden_states)
 
@@ -345,6 +333,7 @@ def forward(
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            cache_position=cache_position,
         )
 
         attn_output = attn_outputs[0]  # output_attn: output, present_key_values, (attn_weights)
@@ -423,7 +412,7 @@ def __init__(self, config):
 
         self.drop = nn.Dropout(config.dropout)
 
-        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
+        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
         self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
@@ -450,7 +439,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         if past_key_values is not None:
             # Omit tokens covered by past_key_values
             seq_len = input_ids.shape[1]
-            past_length = past_key_values[0][0].shape[2]
+            past_length = past_key_values.get_seq_length()
 
             # Some generation methods already pass only the last input ID
             if input_ids.shape[1] > past_length:
@@ -516,6 +505,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
@@ -558,11 +548,24 @@ def forward(
 
         device = input_ids.device if input_ids is not None else input_embeds.device
 
-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.layers))
-        else:
-            past_length = past_key_values[0][0].size(-2)
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        return_legacy_cache = False
+        if use_cache and isinstance(past_key_values, (list, tuple)):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `DynamicCache` instead, e.g. "
+                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+            )
+            return_legacy_cache = True
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
         if position_ids is None:
             position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
@@ -591,18 +594,11 @@ def forward(
         hidden_states = self.drop(input_embeds + position_embeds)
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        present_key_values = () if use_cache else None
+        next_decoder_cache = None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
-        for i, (block, past_layer_key_values) in enumerate(zip(self.layers, past_key_values)):
+        for i, block in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -615,21 +611,23 @@ def forward(
                     head_mask[i],
                     use_cache,
                     output_attentions,
+                    cache_position,
                 )
             else:
                 outputs = block(
                     hidden_states,
-                    past_key_values=past_layer_key_values,
+                    past_key_values=past_key_values,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    cache_position=cache_position,
                 )
 
             hidden_states = outputs[0]
 
             if use_cache:
-                present_key_values = present_key_values + (outputs[1],)
+                next_decoder_cache = outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -644,15 +642,19 @@ def forward(
 
         logits = self.lm_head(hidden_states)
 
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = past_key_values.to_legacy_cache()
+
         if not return_dict:
             return tuple(
-                v for v in [None, logits, present_key_values, all_hidden_states, all_self_attentions] if v is not None
+                v for v in [None, logits, next_cache, all_hidden_states, all_self_attentions] if v is not None
            )
 
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-            past_key_values=present_key_values,
+            past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
@@ -1030,7 +1032,9 @@ def __init__(self, config):
 
         self.drop = nn.Dropout(config.dropout)
 
-        self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
+        self.layers = nn.ModuleList(
+            [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
+        )
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
         self.layernorm_final = nn.LayerNorm(config.hidden_size)
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 2106c07e7dfd..734c49974d6c 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -295,7 +296,7 @@ def forward( class BigBirdSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -313,6 +314,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -328,6 +330,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): mixed_query_layer = self.query(hidden_states) @@ -336,36 +339,38 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - if self.is_decoder: - 
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -1309,7 +1314,7 @@ def __init__(self, config, seed=None): self.seed = seed if self.config.attention_type == "original_full": - self.self = BigBirdSelfAttention(config) + self.self = BigBirdSelfAttention(config, layer_idx=seed) elif self.config.attention_type == "block_sparse": self.self = BigBirdBlockSparseAttention(config, seed) else: @@ -1319,7 +1324,7 @@ def __init__(self, config, seed=None): self.output = BigBirdSelfOutput(config) - def set_attention_type(self, value: str): + def set_attention_type(self, value: str, layer_idx=None): if value not in ["original_full", "block_sparse"]: raise ValueError( f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" @@ -1331,7 +1336,7 @@ def set_attention_type(self, value: str): self.attention_type = value if value == "original_full": # copy all weights to new full attention class - attn_weights = BigBirdSelfAttention(self.config) + attn_weights = BigBirdSelfAttention(self.config, layer_idx=layer_idx) else: # copy all weights to new sparse attention class attn_weights = BigBirdBlockSparseAttention(self.config, self.seed) @@ -1358,6 +1363,7 @@ def forward( to_mask=None, from_blocked_mask=None, to_blocked_mask=None, + cache_position=None, ): # fp16 compatibility if band_mask is not None: @@ -1375,6 +1381,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position=cache_position, ) else: if encoder_hidden_states is not None: @@ -1426,17 +1433,17 @@ def __init__(self, config, seed=None): self.attention_type = config.attention_type self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BigBirdAttention(config, seed=seed) + self.attention = BigBirdAttention(config, seed=seed, layer_idx=seed) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise TypeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BigBirdAttention(config) + self.crossattention = BigBirdAttention(config, layer_idx=seed) self.intermediate = BigBirdIntermediate(config) self.output = BigBirdOutput(config) - def set_attention_type(self, value: str): + def set_attention_type(self, value: str, layer_idx=None): if value not in ["original_full", "block_sparse"]: raise ValueError( f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" @@ -1445,10 +1452,10 @@ def set_attention_type(self, value: str): if value == self.attention_type: return self.attention_type = value - self.attention.set_attention_type(value) + self.attention.set_attention_type(value, layer_idx=layer_idx) if self.add_cross_attention: - 
self.crossattention.set_attention_type(value) + self.crossattention.set_attention_type(value, layer_idx=layer_idx) def forward( self, @@ -1463,33 +1470,32 @@ def forward( blocked_encoder_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, band_mask=band_mask, from_mask=from_mask, to_mask=to_mask, from_blocked_mask=blocked_encoder_mask, to_blocked_mask=blocked_encoder_mask, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -1497,24 +1503,19 @@ def forward( " cross-attention layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -1523,7 +1524,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -1553,8 +1554,8 @@ def set_attention_type(self, value: str): if value == self.attention_type: return self.attention_type = value - for layer in self.layer: - layer.set_attention_type(value) + for i, layer in enumerate(self.layer): + layer.set_attention_type(value, layer_idx=i) def forward( self, @@ -1572,6 +1573,7 @@ def forward( to_mask=None, blocked_encoder_mask=None, return_dict=True, + cache_position=None, ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1584,7 +1586,17 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. 
" + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -1607,6 +1619,7 @@ def forward( blocked_encoder_mask, past_key_value, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -1621,11 +1634,12 @@ def forward( blocked_encoder_mask, past_key_value, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -1634,12 +1648,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -1648,7 +1666,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1905,6 +1923,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -2044,6 +2063,7 @@ def forward( to_mask=to_mask, blocked_encoder_mask=blocked_encoder_mask, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] @@ -2433,6 +2453,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.FloatTensor]]: r""" @@ -2457,6 +2478,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index ffbca32eb9d8..29425b127e5b 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -97,7 +98,7 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97 class BlipTextSelfAttention(nn.Module): - def __init__(self, config, is_cross_attention): + def __init__(self, config, is_cross_attention, layer_idx=None): super().__init__() self.config = config if 
config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): @@ -109,6 +110,7 @@ def __init__(self, config, is_cross_attention): self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.layer_idx = layer_idx self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: @@ -136,11 +138,6 @@ def save_attention_map(self, attention_map): def get_attention_map(self): return self.attention_map - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -148,8 +145,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -158,23 +156,38 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(current_states)) + value_layer = self.transpose_for_scores(self.v_proj(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = 
self.transpose_for_scores(mixed_query_layer)
 
-        past_key_value = (key_layer, value_layer)
-
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
@@ -239,9 +252,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
 
 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242
 class BlipTextAttention(nn.Module):
-    def __init__(self, config, is_cross_attention=False):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
-        self.self = BlipTextSelfAttention(config, is_cross_attention)
+        self.self = BlipTextSelfAttention(config, is_cross_attention, layer_idx=layer_idx)
         self.output = BlipTextSelfOutput(config)
         self.pruned_heads = set()
 
@@ -270,8 +283,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
@@ -281,6 +295,7 @@ def forward(
             encoder_attention_mask,
             past_key_value,
             output_attentions,
+            cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -327,7 +342,9 @@ def __init__(self, config, layer_num):
         self.attention = BlipTextAttention(config)
         self.layer_num = layer_num
         if self.config.is_decoder:
-            self.crossattention = BlipTextAttention(config, is_cross_attention=self.config.is_decoder)
+            self.crossattention = BlipTextAttention(
+                config, is_cross_attention=self.config.is_decoder, layer_idx=layer_num
+            )
         self.intermediate = BlipTextIntermediate(config)
         self.output = BlipTextOutput(config)
 
@@ -338,22 +355,20 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
         )
         attention_output = self_attention_outputs[0]
-
-        outputs = self_attention_outputs[1:-1]
-        present_key_value = self_attention_outputs[-1]
+        outputs = self_attention_outputs[1:]
 
         if encoder_hidden_states is not None:
             cross_attention_outputs = self.crossattention(
@@ -362,7 +377,9 @@ def forward(
                 head_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
+                past_key_value=past_key_value,
                 output_attentions=output_attentions,
+                cache_position=cache_position,
             )
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
@@ -371,7 +388,7 @@
         )
         outputs = (layer_output,) + outputs
 
-        outputs = outputs + (present_key_value,)
+        outputs = outputs + 
(past_key_value,) return outputs @@ -401,6 +418,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -408,11 +426,21 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.is_decoder else None - - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] @@ -420,7 +448,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -428,13 +455,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) all_cross_attentions = all_cross_attentions + (layer_outputs[2],) @@ -442,12 +470,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -456,7 +488,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -671,6 +703,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, is_decoder: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): @@ -778,6 +811,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -835,6 +869,7 @@ def forward( return_logits: Optional[bool] = False, is_decoder: Optional[bool] = True, reduction: 
Optional[str] = "mean", + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of @@ -876,6 +911,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 68f6a04c5a2a..7562d3532049 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -334,40 +335,61 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[bool] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + bsz, seq_len, _ = hidden_states.shape + + if layer_past is not None: + if isinstance(layer_past, EncoderDecoderCache): + is_updated = layer_past.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_past.cross_attention_cache + else: + curr_past_key_value = layer_past.self_attention_cache + else: + curr_past_key_value = layer_past + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`." 
) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask + if layer_past is not None and is_updated: + # reuse k,v, cross_attentions, and compute only q + query = query = self.q_attn(hidden_states) + key = curr_past_key_value.key_cache[self.layer_idx] + value = curr_past_key_value.value_cache[self.layer_idx] + else: + query = query = self.q_attn(hidden_states) + key, value = self.c_attn(current_states).split(self.split_size, dim=2) + key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + query, key, value = self.c_attn(current_states).split(self.split_size, dim=2) + key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + layer_past.is_updated[self.layer_idx] = True - if use_cache is True: - present = (key, value) - else: - present = None + query = query.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) @@ -378,11 +400,11 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) + outputs = (attn_output, layer_past) if output_attentions: outputs += (attn_weights,) - return outputs # a, present, (attentions) + return outputs # a, layer_past, (attentions) class ImageGPTMLP(nn.Module): @@ -421,13 +443,14 @@ def __init__(self, config, layer_idx=None): def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[bool] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -438,6 +461,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -455,11 +479,13 @@ def forward( hidden_states = self.ln_cross_attn(hidden_states) cross_attn_outputs = self.crossattention( hidden_states, + layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, 
encoder_attention_mask=encoder_attention_mask,
                 output_attentions=output_attentions,
+                cache_position=cache_position,
             )
             attn_output = cross_attn_outputs[0]
             # residual connection
@@ -566,6 +592,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
@@ -632,14 +659,28 @@ def forward(
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        return_legacy_cache = False
+        if use_cache and isinstance(past_key_values, (list, tuple)):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            return_legacy_cache = True
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
 
-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0)
@@ -691,27 +732,16 @@ def forward(
             hidden_states = hidden_states + token_type_embeds
 
         hidden_states = self.drop(hidden_states)
 
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- ) - use_cache = False - - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -731,22 +761,24 @@ def forward( encoder_attention_mask, use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] if use_cache is True: - presents = presents + (outputs[1],) + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -760,22 +792,26 @@ def forward( hidden_states = hidden_states.to("cuda:" + str(k + 1)) hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -825,6 +861,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Any, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -907,6 +944,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index d25f100a2348..32b707845041 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -706,9 +707,10 @@ def __init__( embed_dim: int, num_heads: int, dropout: 
float = 0.0, - is_decoder: bool = False, - add_inner_attn_layernorm: bool = False, - bias: bool = True, + is_decoder: Optional[bool] = False, + add_inner_attn_layernorm: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[bool] = None, ): super().__init__() self.config = config @@ -724,6 +726,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -735,22 +738,17 @@ def __init__( if add_inner_attn_layernorm: self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - def _shape(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -758,33 +756,40 @@ def forward( is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + query_states = self.q_proj(hidden_states) + query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states)) - value_states = self._shape(self.v_proj(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = 
self._shape(self.q_proj(hidden_states)) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward @@ -841,7 +846,7 @@ def forward(self, hidden_states): class Kosmos2TextBlock(nn.Module): - def __init__(self, config: Kosmos2TextConfig): + def __init__(self, config: Kosmos2TextConfig, layer_idx=None): super().__init__() self.embed_dim = config.embed_dim @@ -852,6 +857,7 @@ def __init__(self, config: Kosmos2TextConfig): dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -864,6 +870,7 @@ def __init__(self, config: Kosmos2TextConfig): dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=False, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -878,33 +885,29 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - hidden_states = self.self_attn_layer_norm(hidden_states) # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, **kwargs, ) hidden_states 
= nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: if not hasattr(self, "encoder_attn"): @@ -914,26 +917,21 @@ def forward( ) residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -949,7 +947,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -977,7 +975,7 @@ def __init__(self, config: Kosmos2TextConfig): padding_idx=config.pad_token_id, ) - self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)]) + self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)]) self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) self.gradient_checkpointing = False @@ -1058,6 +1056,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1077,8 +1076,24 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values # We don't need img info. 
when `past_key_values_length` > 0 if past_key_values_length > 0: @@ -1105,18 +1120,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_value_states = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1136,8 +1144,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1150,6 +1156,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1161,15 +1168,16 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, **kwargs, ) hidden_states = layer_outputs[0] if use_cache: - present_key_value_states += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1180,13 +1188,17 @@ def forward( # add final layer norm hidden_states = self.layer_norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1353,6 +1365,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" @@ -1386,6 +1399,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -1444,6 +1458,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -1488,6 +1503,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, **kwargs, ) lm_logits = self.lm_head(outputs[0]) @@ -1525,9 +1541,11 @@ def 
prepare_inputs_for_generation(
                 past_key_values_length=0,
             )
 
-        if past_key_values is not None:
+        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+        if cache_position[0] > 0:
             image_embeds = None
             image_embeds_position_mask = None
+
         # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
         elif image_embeds_position_mask is not None:
             batch_size, seq_len = input_ids.size()
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 2b2d17f28d8a..ddfd7097c37d 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -789,9 +789,6 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 5004dd037c4d..e394137204be 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
 
 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...integrations.fsdp import is_fsdp_managed_module
@@ -858,9 +859,10 @@ def __init__(
         self,
         embed_dim: int,
         num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        dropout: Optional[float] = 0.0,
+        is_decoder: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -875,25 +877,24 @@ def __init__(
             )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.layer_idx = layer_idx
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         """Input shape: Batch x Time x Channel"""
 
         # if key_value_states are provided this layer is used as a cross-attention layer
@@ -904,40 +905,44 @@ def forward(
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
 
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
+
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated =
past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -1094,13 +1099,14 @@ def forward( class SpeechT5DecoderLayer(nn.Module): - def __init__(self, config: SpeechT5Config): + def __init__(self, config: SpeechT5Config, 
layer_idx=None): super().__init__() self.self_attn = SpeechT5Attention( embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = nn.Dropout(config.hidden_dropout) self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1110,6 +1116,7 @@ def __init__(self, config: SpeechT5Config): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1124,9 +1131,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -1149,43 +1157,36 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) @@ -1196,7 +1197,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1496,7 +1497,7 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.layerdrop = config.decoder_layerdrop - self.layers = nn.ModuleList([SpeechT5DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False @@ -1516,6 +1517,7 @@ 
def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1585,7 +1587,24 @@ def forward( input_shape = hidden_states.size()[:-1] - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length @@ -1600,18 +1619,11 @@ def forward( synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1634,8 +1646,6 @@ def forward( if skip_the_layer and not synced_gpus: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1648,6 +1658,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1659,14 +1670,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) @@ -1678,6 +1690,9 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1722,6 +1737,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, 
return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: decoder_hidden_states = self.prenet(input_values, speaker_embeddings) @@ -1737,6 +1753,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -1774,6 +1791,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values) @@ -1789,6 +1807,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -1820,6 +1839,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: outputs = self.wrapped_decoder( hidden_states=input_values, @@ -1833,6 +1853,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -2027,6 +2048,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): @@ -2101,6 +2123,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **decoder_args, ) @@ -2183,6 +2206,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -2277,6 +2301,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) logits = self.text_decoder_postnet(outputs[0]) @@ -2515,6 +2540,7 @@ def forward( speaker_embeddings: Optional[torch.FloatTensor] = None, labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -2597,6 +2623,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) @@ -2868,6 +2895,7 @@ def forward( speaker_embeddings: Optional[torch.FloatTensor] = None, labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: r""" input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -2954,6 +2982,7 @@ def 
forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 152243da0845..35169d02d296 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -149,10 +150,11 @@ def __init__( num_heads: int, kdim: Optional[int] = None, vdim: Optional[int] = None, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_cross_attention: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_cross_attention: Optional[bool] = False, + layer_idx: Optional[bool] = None, ): super().__init__() self.embed_dim = embed_dim @@ -168,6 +170,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) @@ -175,17 +178,15 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -196,40 +197,44 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - 
value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -289,7 +294,7 @@ def forward( class TrOCRDecoderLayer(nn.Module): - def __init__(self, config: TrOCRConfig): + def __init__(self, config: TrOCRConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -299,6 +304,7 @@ def __init__(self, config: TrOCRConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -316,6 +322,7 @@ def __init__(self, config: TrOCRConfig): dropout=config.attention_dropout, is_decoder=True, is_cross_attention=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -331,9 +338,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -356,15 +364,13 @@ def forward( residual = hidden_states # Self 
Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -372,30 +378,24 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None - if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -412,7 +412,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -469,7 +469,7 @@ def __init__(self, config: TrOCRConfig): else: self.layernorm_embedding = None - self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -495,6 +495,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -580,8 +581,24 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and isinstance(past_key_values, (list, tuple)): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -611,18 +628,11 @@ def forward( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -641,8 +651,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -655,6 +663,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -666,14 +675,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -686,6 +696,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -771,6 +784,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -851,6 +865,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.output_projection(outputs[0]) From 051fe7f88d1ce349f533f346d60e14cb1901b373 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 5 Jun 2025 15:59:42 +0200 Subject: [PATCH 03/58] fix simple greedy generation --- .../models/align/modeling_align.py | 2 +- .../models/altclip/modeling_altclip.py | 2 +- src/transformers/models/bark/modeling_bark.py | 6 +- src/transformers/models/bert/modeling_bert.py | 8 +- .../modeling_bert_generation.py | 8 +- .../models/big_bird/modeling_big_bird.py | 4 +- .../models/blip/modeling_blip_text.py | 2 +- .../bridgetower/modeling_bridgetower.py | 2 +- .../models/camembert/modeling_camembert.py | 2 +- .../chinese_clip/modeling_chinese_clip.py | 2 +- 
src/transformers/models/clap/modeling_clap.py | 2 +- src/transformers/models/clvp/modeling_clvp.py | 6 +- .../models/cpmant/modeling_cpmant.py | 2 +- src/transformers/models/ctrl/modeling_ctrl.py | 4 +- .../models/data2vec/modeling_data2vec_text.py | 8 +- .../models/electra/modeling_electra.py | 2 +- .../models/ernie/modeling_ernie.py | 8 +- .../models/imagegpt/modeling_imagegpt.py | 2 +- .../models/kosmos2/modeling_kosmos2.py | 4 +- .../models/layoutlm/modeling_layoutlm.py | 2 +- src/transformers/models/led/modeling_led.py | 6 +- .../models/markuplm/modeling_markuplm.py | 2 +- src/transformers/models/mpt/modeling_mpt.py | 10 +- src/transformers/models/mvp/modeling_mvp.py | 4 +- .../models/nllb_moe/modeling_nllb_moe.py | 18 ++-- .../models/prophetnet/modeling_prophetnet.py | 37 ++++--- .../models/rembert/modeling_rembert.py | 2 +- .../models/roberta/modeling_roberta.py | 12 +-- .../modeling_roberta_prelayernorm.py | 18 ++-- .../models/roc_bert/modeling_roc_bert.py | 2 +- .../models/roformer/modeling_roformer.py | 8 +- .../speech_to_text/modeling_speech_to_text.py | 101 ++++++++---------- .../models/speecht5/modeling_speecht5.py | 8 +- .../models/splinter/modeling_splinter.py | 2 +- .../models/trocr/modeling_trocr.py | 6 +- src/transformers/models/xglm/modeling_xglm.py | 26 ++--- .../xlm_roberta/modeling_xlm_roberta.py | 2 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 76 +++++++------ src/transformers/models/xmod/modeling_xmod.py | 6 +- 39 files changed, 210 insertions(+), 214 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 813106d14aec..0a9d52b2fa74 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -948,7 +948,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 023fc68ca10a..73184002b959 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -539,7 +539,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 7cb09c0d95ef..957e59c8120a 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, @@ -439,7 +439,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values is not None: # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - past_length = past_key_values.get_seq_length() + past_length = past_key_values[0][0].shape[-2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -556,7 +556,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 4b8a0d9c0a17..7d8c38785778 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -256,8 +256,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -547,7 +547,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertAttention(config, layer_idx) + self.attention = BertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -654,7 +654,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index d6cce5b2520f..b4d8e3df36ee 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -121,8 +121,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -288,7 +288,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertGenerationAttention(config, layer_idx) + self.attention = BertGenerationAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -398,7 +398,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 734c49974d6c..b8a96ada023e 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import EncoderDecoderCache +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1587,7 +1587,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 29425b127e5b..3a8aa7bcbd97 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -428,7 +428,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 60cec997e500..62b4b3cb4e02 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -792,7 +792,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index bc241dce4438..03bc355a6bac 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -611,7 +611,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 747f3ba7f114..d61493814f8d 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -815,7 +815,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 63736ba2cfad..be2231228867 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1476,7 +1476,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index e5d4b5de9b01..57f0d7e3478a 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1085,7 +1085,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. 
" @@ -1363,7 +1363,7 @@ def prepare_inputs_for_generation( token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - past_length = past_key_values.get_seq_length() + past_length = past_key_values[0][0].shape[-2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1383,7 +1383,7 @@ def prepare_inputs_for_generation( # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[:, -input_ids.shape[1] :].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index a684cd450aee..4f11bb49654b 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -642,7 +642,7 @@ def forward( span = torch.full((batch, seq_length), 0, dtype=dtype, device=device) return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 5375e4c2d49c..3bb2b946c54e 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -343,7 +343,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. 
" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 8ed35e826258..2d852f2f772b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -206,8 +206,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -388,7 +388,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = Data2VecTextAttention(config, layer_idx) + self.attention = Data2VecTextAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -498,7 +498,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a75c00211edf..377c84e668c9 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -557,7 +557,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 54c0bfa97bc6..a458f00ed98b 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -192,8 +192,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -374,7 +374,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ErnieAttention(config, layer_idx) + self.attention = ErnieAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -482,7 +482,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 7562d3532049..1e96a127fb3f 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -667,7 +667,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 32b707845041..86c3d0a65079 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1084,7 +1084,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1093,7 +1093,7 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # We don't need img info. 
when `past_key_values_length` > 0 if past_key_values_length > 0: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index c8fb1fbbb16e..de605c6f498d 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -479,7 +479,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index ddfd7097c37d..0399b797409a 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -843,7 +843,7 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -1890,7 +1890,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1899,7 +1899,7 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index e929b9141f9e..bfc070459e2c 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -639,7 +639,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index dff4553b6a34..e19e33aaab24 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -109,8 +109,7 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale - - query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value.get_seq_length() if position_bias is not None: if len(position_bias.shape) != 3: @@ -376,11 +375,8 @@ def forward( all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 276d7ee5e40a..736b30084d18 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -185,7 +185,7 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -883,7 +883,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b3c83153acc0..6a67973863d4 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -717,6 +717,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False, layer_idx: Op dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -779,42 +780,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -833,7 +827,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states, present_key_value) + outputs = (hidden_states, past_key_value) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 121629e9de0f..c15609493c92 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -533,8 +533,6 @@ def forward( # previous time steps are cached - no need to recompute key and value if they are static query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -553,10 +551,10 @@ def forward( key_states = curr_past_key_value.key_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self.k_proj(current_states) - 
value_states = self.v_proj(current_states) - key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.key_proj(current_states) + value_states = self.value_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -568,12 +566,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = query_states.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2) src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) if attn_weights.size() != expected_shape: @@ -716,9 +711,9 @@ def forward( value_states = self._shape(value_states, -1, batch_size) proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) # chunk into main stream and predict stream hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) @@ -731,9 +726,14 @@ def forward( main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:] main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] - # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) + # ProphetNet has two separate attention layers, one for self and one for cross attention + # We need to obtain the self attention only for this module, if `EncoderDecoderCache` if past_key_value is not None: - prev_main_key_states, main_value_states = past_key_value.update( + if isinstance(past_key_value, EncoderDecoderCache): + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + main_key_states, main_value_states = curr_past_key_value.update( main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -1057,7 +1057,6 @@ def forward( ) hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block @@ -1327,7 +1326,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" @@ -1336,7 +1335,7 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if past_key_values_length != 0: main_relative_position_buckets, predict_relative_position_buckets = None, None diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 3f0ec7f03e0d..91a292fbcbcb 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -518,7 +518,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 1d0b591be9eb..82fcf47e2d2f 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -205,8 +205,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -341,8 +341,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + value_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -502,7 +502,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaAttention(config, layer_idx) + self.attention = RobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -610,7 +610,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 8cc11246f1c8..f7781b88f4a3 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -204,8 +204,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -287,9 +287,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RobertaPreLayerNormAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = RobertaPreLayerNormSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = RobertaPreLayerNormSelfAttention( + config, position_embedding_type=position_embedding_type, layer_idx=layer_idx + ) self.output = RobertaPreLayerNormSelfOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pruned_heads = set() @@ -320,8 +322,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: hidden_states_pre_layer_norm = self.LayerNorm(hidden_states) self_outputs = self.self( @@ -332,6 +335,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -374,7 +378,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaPreLayerNormAttention(config, layer_idx) + self.attention = RobertaPreLayerNormAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -486,7 +490,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 73e0f1ba54f5..dfb43672971e 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -609,7 +609,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 2a8b77c6dbd5..bbe2acbb6e35 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -24,8 +24,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...cache_utils import Cache, EncoderDecoderCache, DynamicCache from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -375,7 +375,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, - cache_position, + cache_position=None, ): self_outputs = self.self( hidden_states, @@ -449,7 +449,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, - cache_position, + cache_position=None, ): self_attention_outputs = self.attention( hidden_states, @@ -538,7 +538,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 9c7530d70d5b..3c86e4b5fa9b 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -213,11 +213,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[Speech2TextConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -234,6 +235,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -244,14 +246,15 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -268,42 +271,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states 
= torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -424,6 +422,7 @@ def __init__(self, config: Speech2TextConfig, layer_idx=None): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -439,9 +438,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -465,42 +465,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights 
= None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -516,7 +509,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -910,7 +903,7 @@ def forward( ) past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = self._update_causal_mask( attention_mask, input_shape, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e394137204be..a89efefc25c7 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN -from ...cache_utiils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -940,7 +940,7 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -1595,7 +1595,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" @@ -1604,7 +1604,7 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f69eb7ad5f4c..1a7063b2a92d 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -451,7 +451,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 35169d02d296..ebcafbcd3fde 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -232,7 +232,7 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -589,7 +589,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and isinstance(past_key_values, (list, tuple)): + if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" @@ -598,7 +598,7 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 7df17973d723..8efed99abe1f 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -188,9 +188,9 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -299,9 +299,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -325,42 +326,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected 
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
@@ -376,7 +370,7 @@ def forward(
             outputs += (self_attn_weights, cross_attn_weights)
 
         if use_cache:
-            outputs += (present_key_value,)
+            outputs += (past_key_value,)
 
         return outputs
 
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 2a508ab915c7..07b153f00b9c 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -611,7 +611,7 @@ def forward(
                 use_cache = False
 
         return_legacy_cache = False
-        if use_cache and isinstance(past_key_values, (list, tuple)):
+        if use_cache and not isinstance(past_key_values, Cache):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index ec5b98013bc8..24e4df58bb0e 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -203,8 +203,8 @@ def forward(
             key_layer = curr_past_key_value.key_cache[self.layer_idx]
             value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            key_layer = self.transpose_for_scores(self.k_proj(current_states))
-            value_layer = self.transpose_for_scores(self.v_proj(current_states))
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
 
         if past_key_value is not None:
             # save all key/value_layer to cache to be re-used for fast auto-regressive generation
@@ -339,8 +339,8 @@ def forward(
             key_layer = curr_past_key_value.key_cache[self.layer_idx]
             value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            value_layer = self.transpose_for_scores(self.k_proj(current_states))
-            value_layer = self.transpose_for_scores(self.v_proj(current_states))
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
 
         if past_key_value is not None:
             # save all key/value_layer to cache to be re-used for fast auto-regressive generation
@@ -406,11 +406,13 @@ def forward(self, hidden_states, input_tensor):
 
 
 class XLMRobertaXLAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
         super().__init__()
         self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.self = XLMROBERTAXL_SELF_ATTENTION_CLASSES[config._attn_implementation](
-            config, position_embedding_type=position_embedding_type
+            config,
+            position_embedding_type=position_embedding_type,
+            layer_idx=layer_idx,
         )
         self.output = XLMRobertaXLSelfOutput(config)
         self.pruned_heads = set()
@@ -442,6 +444,7 @@ def forward(
         encoder_attention_mask=None,
         past_key_value=None,
         output_attentions=False,
+        cache_position=None,
     ):
         intermediate = self.self_attn_layer_norm(hidden_states)
         self_outputs = self.self(
@@ -452,6 +455,7 @@ def forward(
             encoder_attention_mask,
             past_key_value,
             output_attentions,
+            cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -486,17 +490,19 @@ def forward(self,
hidden_states, input_tensor): class XLMRobertaXLLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaXLAttention(config) + self.attention = XLMRobertaXLAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XLMRobertaXLAttention(config, position_embedding_type="absolute") + self.crossattention = XLMRobertaXLAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = XLMRobertaXLIntermediate(config) self.output = XLMRobertaXLOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -510,26 +516,24 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -537,24 +541,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -562,7 +561,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -577,7 +576,7 @@ class XLMRobertaXLEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([XLMRobertaXLLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([XLMRobertaXLLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.LayerNorm = 
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -593,6 +592,7 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): if self.gradient_checkpointing and self.training: if use_cache: @@ -600,17 +600,27 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -620,8 +630,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -630,13 +641,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -647,12 +659,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -661,7 +677,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index d9746623afaf..aa6d9c3df3d2 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -203,8 +203,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = 
self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -590,7 +590,7 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: From b04ddbc6338566116a2f71b1f278d8ce5736e7fe Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 5 Jun 2025 16:53:29 +0200 Subject: [PATCH 04/58] xmod --- src/transformers/models/xmod/modeling_xmod.py | 70 ++++++++++++------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index aa6d9c3df3d2..0423622b8963 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -288,9 +288,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XmodAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = XmodSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = XmodSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.output = XmodSelfOutput(config) self.pruned_heads = set() self.pre_norm = config.pre_norm @@ -321,8 +321,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: residual = hidden_states if self.pre_norm: @@ -335,6 +336,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], residual) if not self.pre_norm: @@ -428,17 +430,17 @@ def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor): class XmodLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XmodAttention(config) + self.attention = XmodAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XmodAttention(config, position_embedding_type="absolute") + self.crossattention = XmodAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = XmodIntermediate(config) self.output = XmodOutput(config) self.pre_norm = config.pre_norm @@ -451,28 +453,26 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: 
Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -480,24 +480,19 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, - cross_attn_past_key_value, + past_key_value, output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - residual = attention_output if self.pre_norm: attention_output = self.output.LayerNorm(attention_output) @@ -514,7 +509,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -526,7 +521,7 @@ class XmodEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([XmodLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([XmodLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.is_pre_norm = config.pre_norm if self.is_pre_norm: self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -545,6 +540,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -552,17 +548,27 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -573,8 +579,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -584,8 +591,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] @@ -602,12 +610,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -616,7 +628,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -760,6 +772,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -852,6 +865,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -916,6 +930,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -963,6 +978,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] From 6a289a7cd7c797e335e63e8a5cd6f7cd9f3736ea Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 6 Jun 2025 10:57:12 +0200 Subject: [PATCH 05/58] add fmst and clean up some mentions of old cache format --- src/transformers/generation/utils.py | 13 +- .../models/altclip/modeling_altclip.py | 2 +- src/transformers/models/bark/modeling_bark.py | 2 +- .../models/bart/modeling_tf_bart.py | 2 +- src/transformers/models/bert/modeling_bert.py | 2 +- .../modeling_bert_generation.py | 2 +- 
.../models/big_bird/modeling_big_bird.py | 2 +- .../models/blip/modeling_blip_text.py | 4 +- .../models/blip_2/modeling_blip_2.py | 2 +- .../models/bloom/modeling_bloom.py | 10 +- .../bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- .../models/camembert/modeling_camembert.py | 2 +- .../chinese_clip/modeling_chinese_clip.py | 2 +- src/transformers/models/clap/modeling_clap.py | 2 +- src/transformers/models/clvp/modeling_clvp.py | 2 +- src/transformers/models/ctrl/modeling_ctrl.py | 2 +- .../models/data2vec/modeling_data2vec_text.py | 2 +- .../deprecated/ernie_m/modeling_ernie_m.py | 2 +- .../modeling_gptsan_japanese.py | 2 +- .../models/deprecated/nezha/modeling_nezha.py | 2 +- .../open_llama/modeling_open_llama.py | 4 +- .../deprecated/qdqbert/modeling_qdqbert.py | 4 +- .../models/deprecated/realm/modeling_realm.py | 2 +- .../modeling_speech_to_text_2.py | 4 +- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 2 +- .../models/electra/modeling_electra.py | 2 +- .../models/ernie/modeling_ernie.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- .../models/falcon/modeling_falcon.py | 10 +- src/transformers/models/fsmt/modeling_fsmt.py | 183 ++++++++---------- src/transformers/models/fuyu/modeling_fuyu.py | 2 +- src/transformers/models/git/modeling_git.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 12 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 8 +- .../models/gpt_neo/modeling_gpt_neo.py | 10 +- .../models/imagegpt/modeling_imagegpt.py | 6 +- .../instructblip/modeling_instructblip.py | 2 +- .../modeling_instructblipvideo.py | 2 +- .../megatron_bert/modeling_megatron_bert.py | 2 +- src/transformers/models/mpt/modeling_mpt.py | 10 +- .../models/musicgen/modeling_musicgen.py | 4 +- .../modeling_musicgen_melody.py | 4 +- .../models/prophetnet/modeling_prophetnet.py | 2 +- .../models/rembert/modeling_rembert.py | 2 +- .../models/roberta/modeling_roberta.py | 2 +- .../modeling_roberta_prelayernorm.py | 2 +- .../models/roc_bert/modeling_roc_bert.py | 4 +- .../models/roformer/modeling_roformer.py | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 4 +- .../models/speecht5/modeling_speecht5.py | 2 +- .../models/splinter/modeling_splinter.py | 2 +- .../xlm_roberta/modeling_xlm_roberta.py | 2 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 4 +- src/transformers/models/xmod/modeling_xmod.py | 2 +- tests/generation/test_utils.py | 4 +- 56 files changed, 176 insertions(+), 200 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9af619c82032..8a0b9103c036 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1972,10 +1972,8 @@ def _supports_default_dynamic_cache(self) -> bool: for `HybridMambaAttentionDynamicCache`). 
""" return ( - self._supports_cache_class - and "jamba" not in self.__class__.__name__.lower() - and "zamba" not in self.__class__.__name__.lower() - and "bamba" not in self.__class__.__name__.lower() + special_model_name not in self.__class__.__name__.lower() + for special_model_name in ["jamba", "zamba", "mamba"] ) def _prepare_cache_for_generation( @@ -3703,12 +3701,9 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx): for this function, with `Cache.reorder_cache` being the sole remaining code path """ model_class = self.__class__.__name__.lower() - # Exception 1: code path for models using the legacy cache format - if isinstance(past_key_values, (tuple, list)): - past_key_values = self._reorder_cache(past_key_values, beam_idx) - # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # Exception: models with different cache formats. These are limited to `DynamicCache` until their # cache format is standardized, to avoid adding complexity to the codebase. - elif "gptbigcode" in model_class: + if "gptbigcode" in model_class: if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): raise ValueError( f"Using an unsupported cache format with {model_class}. Currently, it only supports the " diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 73184002b959..5c983f05de87 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1231,7 +1231,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 957e59c8120a..cc342b647201 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -439,7 +439,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values is not None: # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - past_length = past_key_values[0][0].shape[-2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 7ab9817986e6..a772f0fd2346 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1539,7 +1539,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 7d8c38785778..f754cade4495 100755 --- 
a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -969,7 +969,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index b4d8e3df36ee..c0487490f74f 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -709,7 +709,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index b8a96ada023e..18e43e7c2bb4 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1951,7 +1951,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 3a8aa7bcbd97..14d9cd9ae6e7 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -753,7 +753,7 @@ def forward( raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device) @@ -953,7 +953,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 5945f4f48ce6..a0ecbf243c27 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1189,7 +1189,7 @@ def forward( # past_key_values_length past_key_values_length = ( - past_key_values[0][0].shape[2] - 
self.config.query_length if past_key_values is not None else 0 + past_key_values.get_seq_length() - self.config.query_length if past_key_values is not None else 0 ) query_length = ( diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index bdba37a73b01..ce67dd208356 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -519,7 +519,7 @@ def forward( ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -887,7 +887,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1021,7 +1021,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1157,7 +1157,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1248,7 +1248,7 @@ def forward( ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
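Note on the recurring substitution in these hunks: `past_key_values[0][0].shape[2]` is replaced everywhere by `past_key_values.get_seq_length()`. As an illustrative sketch outside the patch, assuming only the public `DynamicCache` API from `transformers.cache_utils`, the legacy indexing and the new call report the same cached length:

# Illustrative only, not part of the patch.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.zeros(1, 8, 5, 64)    # (batch, num_heads, seq_len, head_dim)
value = torch.zeros(1, 8, 5, 64)
cache.update(key, value, layer_idx=0)

legacy = cache.to_legacy_cache()  # tuple of per-layer (key, value) tensors
assert legacy[0][0].shape[2] == cache.get_seq_length() == 5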
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 62b4b3cb4e02..45661e9f3e1d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1115,7 +1115,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 100337d96b75..cadd1225427a 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -783,7 +783,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 03bc355a6bac..5a96c2234771 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -882,7 +882,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index d61493814f8d..7a8948a8045a 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -1115,7 +1115,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index be2231228867..912cce28b5e8 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1740,7 +1740,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if 
past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 57f0d7e3478a..a5377adb2a08 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1363,7 +1363,7 @@ def prepare_inputs_for_generation( token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - past_length = past_key_values[0][0].shape[-2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 3bb2b946c54e..78bf0f8c02af 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -575,7 +575,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac # only last tokens for inputs_ids if past is defined in kwargs if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 2d852f2f772b..757b5cd4ef30 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -706,7 +706,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index 28c17afa3f7a..3e55e562d799 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -539,7 +539,7 @@ def forward( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_seq_length() # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel if attention_mask is None: diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index 17da733be978..47f6a0191765 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -919,7 +919,7 @@ def forward( num_batch = input_ids.shape[0] pasts_or_spout_value = None if past_key_values is not None: - num_pasts_contexts = past_key_values[0][0].shape[2] + num_pasts_contexts = past_key_values.get_seq_length() elif self.config.d_spout and spout is not None: # `spout` is a special input vector specific to GPTSAN # This controls the output 
by projecting embedded information such as the class of sentences during learning. diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 7be52bee5847..bf269e6e6240 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -936,7 +936,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 79d79ea546a9..022567e07b97 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -591,7 +591,7 @@ def forward( use_cache = False if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -804,7 +804,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index 8b68a4e426d4..30107b5dcff5 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -924,7 +924,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1138,7 +1138,7 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index ac25a177333e..6f3a4df61648 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -1058,7 +1058,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if 
past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 6f1dd18d97ff..9bd1a4d271fe 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -558,7 +558,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -899,7 +899,7 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index 17bc9ffada60..82ce291e8d26 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -596,7 +596,7 @@ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=Non if past_key_values is not None: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX - prev_num_input_ids = past_key_values[0][0].shape[2] + prev_num_input_ids = past_key_values.get_seq_length() num_input_ids = inputs_shape[1] + prev_num_input_ids position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( int(self.padding_idx + num_input_ids) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 377c84e668c9..05eb5f9ec850 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -788,7 +788,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index a458f00ed98b..8c1ef579a8e0 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -805,7 +805,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = 
torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 640b8429f008..7ec40a852871 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -928,7 +928,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index df87d36242e0..2f54b8fcee27 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -742,7 +742,7 @@ def forward( ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1074,7 +1074,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1196,7 +1196,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1322,7 +1322,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
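Like the Xmod encoder earlier in this series, the FSMT decoder below accepts the old tuple format only through a deprecation shim built on `EncoderDecoderCache.from_legacy_cache`, converting back with `to_legacy_cache()` before returning. A minimal round trip, illustrative only and assuming the public `transformers.cache_utils` API, with the legacy per-layer layout (self-attention key/value, then cross-attention key/value):

# Illustrative only, not part of the patch.
import torch
from transformers.cache_utils import EncoderDecoderCache

legacy = (
    (
        torch.zeros(1, 8, 3, 64), torch.zeros(1, 8, 3, 64),  # self-attention k/v
        torch.zeros(1, 8, 7, 64), torch.zeros(1, 8, 7, 64),  # cross-attention k/v
    ),
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.self_attention_cache.get_seq_length() == 3
assert cache.cross_attention_cache.get_seq_length() == 7
assert cache.to_legacy_cache()[0][0].shape[2] == 3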
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1401,7 +1401,7 @@ def forward( ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 39f65179b82b..2cba9db9b592 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -28,13 +28,14 @@ """PyTorch Fairseq model, ported from https://github.com/pytorch/fairseq/tree/master/examples/wmt19""" import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor, nn from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( @@ -452,7 +453,7 @@ def forward( class DecoderLayer(nn.Module): - def __init__(self, config: FSMTConfig): + def __init__(self, config: FSMTConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -460,6 +461,7 @@ def __init__(self, config: FSMTConfig): embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -471,6 +473,7 @@ def __init__(self, config: FSMTConfig): config.decoder_attention_heads, dropout=config.attention_dropout, encoder_decoder_attention=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -488,12 +491,10 @@ def forward( cross_attn_layer_head_mask=None, decoder_padding_mask=None, output_attentions=False, + cache_position=None, ): residual = x - if layer_state is None: - layer_state = {} - # Self Attention x, self_attn_weights = self.self_attn( query=x, @@ -503,6 +504,7 @@ def forward( attn_mask=causal_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -518,6 +520,7 @@ def forward( layer_state=layer_state, # mutates layer state layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -559,7 +562,7 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding): self.embed_positions = SinusoidalPositionalEmbedding( config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx ) - self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.decoder_layers)]) # type: List[DecoderLayer] + self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in 
range(config.decoder_layers)]) # type: List[DecoderLayer] if is_deepspeed_zero3_enabled(): import deepspeed @@ -585,10 +588,11 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., @@ -645,6 +649,17 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + x += positions x = nn.functional.dropout(x, p=self.dropout, training=self.training) @@ -656,7 +671,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attns = () if output_attentions else None - next_decoder_cache = [] + next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -676,22 +691,21 @@ def forward( if dropout_probability < self.layerdrop: continue - layer_state = past_key_values[idx] if past_key_values is not None else None - x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( x, encoder_hidden_states, encoder_attn_mask=encoder_padding_mask, decoder_padding_mask=decoder_padding_mask, - layer_state=layer_state, + layer_state=past_key_values, causal_mask=decoder_causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), output_attentions=output_attentions, + cache_position=cache_position, ) if use_cache: - next_decoder_cache.append(layer_past.copy()) + next_decoder_cache = layer_past if output_attentions: all_self_attns += (layer_self_attn,) @@ -710,6 +724,8 @@ def forward( x = self.output_projection(x) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( @@ -741,6 +757,7 @@ def __init__( dropout=0.0, bias=True, encoder_decoder_attention=False, # otherwise self_attention + layer_idx=None, ): super().__init__() self.embed_dim = embed_dim @@ -749,6 +766,7 @@ def __init__( self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 + self.layer_idx = layer_idx self.encoder_decoder_attention = encoder_decoder_attention self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -757,64 +775,65 @@ def __init__( self.out_proj = 
nn.Linear(embed_dim, embed_dim, bias=bias) self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" - def _shape(self, tensor, seq_len, bsz): - return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - def forward( self, query, key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, - layer_state: Optional[Dict[str, Optional[Tensor]]] = None, + layer_state: Optional[Cache] = None, attn_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - output_attentions=False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" - static_kv: bool = self.encoder_decoder_attention tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] - # get here for encoder decoder cause of static_kv - if layer_state is not None: # reuse k,v and encoder_padding_mask - saved_state = layer_state.get(self.cache_key, {}) - if "prev_key" in saved_state and static_kv: - # previous time steps are cached - no need to recompute key and value if they are static - key = None - else: - saved_state = None - layer_state = {} - q = self.q_proj(query) * self.scaling - if static_kv: - if key is None: - k = v = None + if layer_state is not None: + if isinstance(layer_state, EncoderDecoderCache): + is_updated = layer_state.is_updated.get(self.layer_idx) + if self.encoder_decoder_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_state.cross_attention_cache + else: + curr_past_key_value = layer_state.self_attention_cache else: - k = self.k_proj(key) - v = self.v_proj(key) + curr_past_key_value = layer_state + + current_states = key if self.encoder_decoder_attention else query + if self.encoder_decoder_attention and layer_state is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - k = self.k_proj(query) - v = self.v_proj(query) - - q = self._shape(q, tgt_len, bsz) - if k is not None: - k = self._shape(k, -1, bsz) - if v is not None: - v = self._shape(v, -1, bsz) - - if saved_state is not None: - k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) - - # Update cache - layer_state[self.cache_key] = { - "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), - "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), - "prev_key_padding_mask": key_padding_mask if not static_kv else None, - } + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if layer_state is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not self.encoder_decoder_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if self.encoder_decoder_attention: + layer_state.is_updated[self.layer_idx] = True + + 
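The rewritten FSMT attention above picks the self-attention or cross-attention sub-cache of an `EncoderDecoderCache`, re-reads the cross-attention keys/values once they have been written, and otherwise appends through `Cache.update`. A condensed sketch of that control flow, illustrative only (the helper name and shapes are mine, not part of the patch):

# Condensed sketch of the cache handling above; illustrative only.
from typing import Optional
import torch
from transformers.cache_utils import Cache, EncoderDecoderCache

def lookup_or_update(layer_state: Cache, key_states, value_states, layer_idx: int,
                     is_cross_attention: bool, cache_position: Optional[torch.Tensor]):
    curr = layer_state
    if isinstance(layer_state, EncoderDecoderCache):
        curr = layer_state.cross_attention_cache if is_cross_attention else layer_state.self_attention_cache
        if is_cross_attention and layer_state.is_updated.get(layer_idx):
            # encoder states are static: after the first step, just re-read them
            return curr.key_cache[layer_idx], curr.value_cache[layer_idx]
    # cache_position only describes the growing self-attention cache
    cache_position = None if is_cross_attention else cache_position
    key_states, value_states = curr.update(
        key_states, value_states, layer_idx, {"cache_position": cache_position}
    )
    if is_cross_attention and isinstance(layer_state, EncoderDecoderCache):
        layer_state.is_updated[layer_idx] = True
    return key_states, value_states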
query_states = self.q_proj(query) * self.scaling - assert k is not None - src_len = k.size(1) - attn_weights = torch.bmm(q, k.transpose(1, 2)) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + assert key_states is not None + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) if attn_mask is not None: @@ -857,45 +876,14 @@ def forward( training=self.training, ) - assert v is not None - attn_output = torch.bmm(attn_probs, v) + assert value_states is not None + attn_output = torch.bmm(attn_probs, value_states) assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped - def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): - # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) - if "prev_key" in saved_state: - _prev_key = saved_state["prev_key"] - assert _prev_key is not None - prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) - if static_kv: - k = prev_key - else: - assert k is not None - k = torch.cat([prev_key, k], dim=1) - if "prev_value" in saved_state: - _prev_value = saved_state["prev_value"] - assert _prev_value is not None - prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) - if static_kv: - v = prev_value - else: - assert v is not None - v = torch.cat([prev_value, v], dim=1) - assert k is not None and v is not None - prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) - if prev_key_padding_mask is not None: - if static_kv: - new_key_padding_mask = prev_key_padding_mask - else: - new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) - else: - new_key_padding_mask = key_padding_mask - return k, v, new_key_padding_mask - def fill_with_neg_inf(t): """FP16-compatible function that fills a input_ids with -inf.""" @@ -953,6 +941,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1033,6 +1022,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1098,6 +1088,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1161,6 +1152,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = outputs[0] @@ -1189,17 +1181,6 @@ def forward( def 
prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = [] - for layer_past in past_key_values: - # get the correct batch idx from decoder layer's batch dim for cross and self-attn - layer_past_new = { - attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() - } - reordered_past.append(layer_past_new) - return reordered_past - def get_encoder(self): return self.model.encoder diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 78b7ae7d4db9..c5e998114fb0 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -188,7 +188,7 @@ def forward( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c63b00fe6b3a..7e1399f9b479 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1151,7 +1151,7 @@ def forward( past_key_values_length = 0 if past_key_values is not None: past_key_values_length = ( - past_key_values[0][0].shape[2] + past_key_values.get_seq_length() if not isinstance(past_key_values, Cache) else past_key_values.get_seq_length() ) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 4b3853b43691..3e66ede9f1dc 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -793,7 +793,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1190,7 +1190,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1343,7 +1343,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1506,7 +1506,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1639,7 +1639,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1725,7 +1725,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7f40aabefb5b..225c05acb349 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -736,7 +736,7 @@ def forward( r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1056,7 +1056,7 @@ def forward( r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1169,7 +1169,7 @@ def forward( r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1300,7 +1300,7 @@ def forward( r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 95de9e82d5ec..3937d106e42c 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -550,7 +550,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -852,7 +852,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -976,7 +976,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1099,7 +1099,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1181,7 +1181,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 1e96a127fb3f..aa8201489531 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -598,7 +598,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -867,7 +867,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1023,7 +1023,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 8018dbe76a95..06e92a170178 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1102,7 +1102,7 @@ def forward( # past_key_values_length past_key_values_length = ( - past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + past_key_values.get_seq_length() - self.config.query_length if past_key_values is not None else 0 ) query_length = query_embeds.shape[1] if query_embeds is not None else 0 diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index cc18bbf90b63..6f4d599cada0 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1068,7 +1068,7 @@ def forward( # past_key_values_length past_key_values_length = ( - past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + past_key_values.get_seq_length() - self.config.query_length if past_key_values is not None else 0 ) query_length = query_embeds.shape[1] if query_embeds is not None else 0 diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index eeafbbf5e5b8..620450abfb25 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -835,7 +835,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index e19e33aaab24..fe2d5d142299 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -321,7 +321,7 @@ def forward( ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -485,7 +485,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -604,7 +604,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -728,7 +728,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -806,7 +806,7 @@ def forward( ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
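Illustrative aside, not part of the patch: the docstring updates in these hunks all point callers at the `Cache` API. A minimal sketch of how `get_seq_length()` replaces the old `past_key_values[0][0].shape[-2]` tuple indexing, using a standalone `DynamicCache` with made-up tensor shapes:

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    # (batch, num_heads, seq_len, head_dim) -- arbitrary example shapes
    key = torch.zeros(1, 4, 7, 64)
    value = torch.zeros(1, 4, 7, 64)
    cache.update(key, value, layer_idx=0)
    print(cache.get_seq_length())  # 7, the same value the old tuple indexing returned

The same object also exposes `crop(max_length)` and `to_legacy_cache()` for code paths that still expect tuples.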
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 42b05671333c..295be0da2d3d 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -563,7 +563,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) @@ -1909,7 +1909,7 @@ def prepare_inputs_for_generation( decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 4e1ea39e754e..92c3a8372b29 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -532,7 +532,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) @@ -1813,7 +1813,7 @@ def prepare_inputs_for_generation( decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index c15609493c92..583c1ba48341 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -463,7 +463,7 @@ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=Non if past_key_values is not None: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX - prev_num_input_ids = past_key_values[0][0].shape[2] + prev_num_input_ids = past_key_values.get_seq_length() num_input_ids = inputs_shape[1] + prev_num_input_ids position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( int(self.padding_idx + num_input_ids) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 91a292fbcbcb..a38bbbc767c8 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -752,7 +752,7 @@ def forward( 
device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 82fcf47e2d2f..298e342b39dc 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -822,7 +822,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index f7781b88f4a3..88f23c2c9f6e 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -709,7 +709,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index dfb43672971e..1f15223b4d45 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -900,7 +900,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1487,7 +1487,7 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index bbe2acbb6e35..21b8c42d6e21 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -888,7 +888,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 
past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 6e71308e17ac..8a0a8309c424 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -545,7 +545,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -744,7 +744,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index a89efefc25c7..7d522972410f 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -823,7 +823,7 @@ def forward( else: raise ValueError("You have to specify `decoder_input_ids`") - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 positions = self.embed_positions(input_ids, past_key_values_length) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 1a7063b2a92d..16776b0a3719 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -639,7 +639,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 07b153f00b9c..9a12e56c517b 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -813,7 +813,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 
past_key_values.get_seq_length() if past_key_values is not None else 0 if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 24e4df58bb0e..b2fbb4dc03e5 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -808,7 +808,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -1039,7 +1039,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 0423622b8963..1263b805977e 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -804,7 +804,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if lang_ids is None: if self.config.default_language is None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 617ba23ebd6b..df045a35c797 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -545,8 +545,8 @@ def test_greedy_generate_dict_outputs_use_cache(self): if self.has_attentions: config._attn_implementation = "eager" # can't output attentions otherwise - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + # if not hasattr(config.get_text_config(), "use_cache"): + # self.skipTest(reason=f"{model_class.__name__} doesn't support caching") if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") From b3be72b46c428f56c0cff31e2ab6281609e50774 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 9 Jun 2025 12:15:19 +0200 Subject: [PATCH 06/58] gpt-bigcode now follows standards --- src/transformers/models/bert/modeling_bert.py | 4 +- .../gpt_bigcode/configuration_gpt_bigcode.py | 1 + .../gpt_bigcode/modeling_gpt_bigcode.py | 718 ++++-------------- .../models/prophetnet/modeling_prophetnet.py | 14 +- tests/generation/test_utils.py | 4 +- 5 files changed, 174 insertions(+), 567 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index f754cade4495..1c9317d844d0 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -391,8 +391,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = 
curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 46a3dfea4410..127a0eed4732 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -134,6 +134,7 @@ def __init__( self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.multi_query = multi_query + self.num_key_value_heads = 1 if multi_query else n_head self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 225c05acb349..05e59ac1ce0b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -14,7 +14,7 @@ """PyTorch GPTBigCode model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -22,26 +22,27 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( auto_docstring, + can_return_tuple, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -77,6 +78,49 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -89,6 +133,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.head_dim = self.embed_dim // self.num_heads self.kv_heads = 1 if self.multi_query else self.num_heads self.kv_dim = self.kv_heads * self.head_dim + self.num_key_value_groups = self.num_heads // self.kv_heads self.split_size = self.embed_dim self.is_causal = True @@ -99,6 +144,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) self.scale_attn_weights = config.scale_attn_weights + self.scaling = self.head_dim**0.5 if config.scale_attn_weights else 1.0 self.is_cross_attention = is_cross_attention self.layer_idx = layer_idx @@ -119,413 +165,93 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.attn_dropout = config.attn_pdrop self.resid_dropout = nn.Dropout(config.resid_pdrop) - def _get_mask_value(self, device, dtype): - # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
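Illustrative aside, not part of the patch: the `repeat_kv` helper introduced above expands the single key/value head of a multi-query checkpoint to one copy per attention head before the attention matmul. A minimal sketch reusing that helper, with made-up shapes:

    import torch

    kv = torch.zeros(1, 1, 7, 64)       # (batch, num_key_value_heads=1, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=12)   # -> (1, 12, 7, 64), one copy per attention head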
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) - return self.mask_value - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - dtype = query.dtype - softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype - upcast = dtype != softmax_dtype - - unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 - scale_factor = unscale**-1 - if self.scale_attn_weights: - scale_factor /= self.head_dim**0.5 - - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-1) - if self.multi_query: - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) - if query.device.type == "cpu": - # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. - # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, - # but the fix has not been released as of pytorch version 2.0.0. - attn_weights = torch.zeros_like(attn_weights) - beta = 1 - else: - beta = 0 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) - - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. - if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
- attn_weights = torch.where(attention_mask, attn_weights, mask_value) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - if self.multi_query: - head_mask = head_mask.transpose(1, 2) - attn_weights = attn_weights * head_mask - - if self.multi_query: - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) - else: - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + input_shape = hidden_states.shape[:-1] if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) - - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -class GPTBigCodeFlashAttention2(GPTBigCodeAttention): - """ - GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module - stays untouched. The only required change would be on the forward pass where it needs to correctly call the public - API of flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + if isinstance(layer_past, EncoderDecoderCache): + is_updated = layer_past.is_updated.get(self.layer_idx) + if self.is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_past.cross_attention_cache + else: + curr_past_key_value = layer_past.self_attention_cache + else: + curr_past_key_value = layer_past - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Optional[torch.Tensor]], - Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], - ]: - if encoder_hidden_states is not None: + if self.is_cross_attention: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. 
- query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) - - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - if self.multi_query: - batch_size, query_length, _ = query.shape - query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) - key = key.unsqueeze(2) - value = value.unsqueeze(2) + if layer_past is not None and is_updated: + # reuse k,v, cross_attentions + key = curr_past_key_value.key_cache[self.layer_idx] + value = curr_past_key_value.value_cache[self.layer_idx] + else: + query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2) + key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1) else: - query_length = query.shape[2] - batch_size, _, tgt, _ = key.shape - query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) - key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) - value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) - - attn_dropout = self.attn_pdrop if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + if self.multi_query: + query, key, value = ( + self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3) + ) + query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2) else: - target_dtype = self.c_attn.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) + query, key, value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split(3 * [self.head_dim], dim=3) + ) - attn_output = _flash_attention_forward( + if layer_past is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not self.is_cross_attention else None + key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if self.is_cross_attention: + layer_past.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, query, key, value, attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - attn_output = self.c_proj(attn_weights_reshaped) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) - else: - attn_weights_reshaped = None - - outputs += (attn_weights_reshaped,) - - return outputs # a, present, (attentions) - - -class GPTBigCodeSdpaAttention(GPTBigCodeAttention): - def _attn(self, query, key, value, attention_mask=None): - scale = None - if not self.scale_attn_weights: - scale = 1 - - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key.shape[-2] - - if self.multi_query: - query_length = query_shape[1] - - # SDPA requires the dimension [..., sequence_length, head_dim]. - query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) - - # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. - key = key.unsqueeze(1) - value = value.unsqueeze(1) - - # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend - # and flash attention backend (No available kernel. Aborting execution.) from the shapes - # query = [batch_size, num_heads, query_length, head_dim] - # key = [batch_size, 1, past_length, head_dim] - # value = [batch_size, 1, past_length, head_dim] - # - # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. - if is_torch_greater_or_equal_than_2_2: - key = key.expand(-1, self.num_heads, -1, -1) - value = value.expand(-1, self.num_heads, -1, -1) - else: - query_length = query_shape[-1] - - # See the comment above. 
- if query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not - # create a causal mask in case query_length == 1. - is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False - - sdpa_result = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.attn_pdrop if self.training else 0.0, - is_causal=is_causal, - scale=scale, + dropout=0.0 if not self.training else self.attn_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, ) - if self.multi_query: - # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) - sdpa_result = sdpa_result.transpose(1, 2) - - # Reshape is kind of expensive here, as it does a memory copy, - # but I did not manage to make away without it (logits do not match when using view) - # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) - sdpa_result = sdpa_result.reshape(query_shape) - - return sdpa_result, None - - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Optional[torch.Tensor]], - Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], - ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) - - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - if not output_attentions: - # Difference with the original implementation: there is no need to transpose the key here, - # as SDPA expects seq_length to be at index -2 for the key as well - attn_output, attn_weights = self._attn(query, key, value, attention_mask) - else: - # TODO: Improve this warning with e.g. 
`model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`." - ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask) - - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs + return attn_output, layer_past, attn_weights class GPTBigCodeMLP(nn.Module): @@ -546,13 +272,6 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPTBIGCODE_ATTENTION_CLASSES = { - "eager": GPTBigCodeAttention, - "flash_attention_2": GPTBigCodeFlashAttention2, - "sdpa": GPTBigCodeSdpaAttention, -} - - class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() @@ -561,7 +280,7 @@ def __init__(self, config, layer_idx=None): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -569,9 +288,7 @@ def __init__(self, config, layer_idx=None): if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( - config, is_cross_attention=True, layer_idx=layer_idx - ) + self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -580,13 +297,14 @@ def __init__(self, config, layer_idx=None): def forward( self, hidden_states: Optional[Tuple[torch.Tensor]], - layer_past: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -600,6 +318,8 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -622,6 +342,8 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, 
+ **kwargs, ) attn_output = cross_attn_outputs[0] # residual connection @@ -716,6 +438,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -732,6 +455,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): @@ -754,10 +479,9 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) batch_size = input_ids.shape[0] @@ -770,81 +494,43 @@ def forward( if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - device = input_ids.device if input_ids is not None else inputs_embeds.device + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0].size(-2) - - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - - # Self-attention mask. 
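Illustrative aside, not part of the patch: the deprecation shim added in this hunk wraps legacy tuple caches with `EncoderDecoderCache.from_legacy_cache` so downstream code only ever sees a `Cache` object. Callers can do the same conversion themselves; a minimal sketch, assuming `legacy_pkv` is such a tuple-of-tuples cache from an older code path:

    from transformers import EncoderDecoderCache

    cache = EncoderDecoderCache.from_legacy_cache(legacy_pkv)
    print(cache.get_seq_length())         # tokens already stored in the self-attention cache
    legacy_pkv = cache.to_legacy_cache()  # convert back for code that still expects tuples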
- query_length = input_shape[-1] - key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( encoder_attention_mask.bool() if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None ) else: - # 4d mask is passed through the layers - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) - - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - - if self._use_sdpa and head_mask is None and not output_attentions: - # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. - dtype = self.wte.weight.dtype - min_dtype = torch.finfo(dtype).min - self_attention_mask = torch.where( - self_attention_mask, - torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), - ) - - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - if self.multi_query: - # gpt_bigcode using MQA has the bad taste to use a causal mask with shape - # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. - self_attention_mask = self_attention_mask.transpose(1, 2) - - if ( - query_length > 1 - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - ): - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. 
Details: https://github.com/pytorch/pytorch/issues/110213 - self_attention_mask = AttentionMaskConverter._unmask_unattended( - self_attention_mask, min_dtype=min_dtype - ) - - attention_mask = self_attention_mask - # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if ( @@ -865,24 +551,22 @@ def forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device) if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - presents = [] if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -891,28 +575,31 @@ def forward( block.__call__, hidden_states, None, - attention_mask, + causal_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, + layer_past=past_key_values, + attention_mask=causal_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] if use_cache: - presents.append(outputs[1]) + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -926,16 +613,13 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None - ) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -965,75 +649,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - # Overwritten -- `past_key_values` with uncommon shape - - token_type_ids = kwargs.get("token_type_ids", None) - # Omit tokens covered by past_key_values - if past_key_values: - if self.config.multi_query: - past_length = past_key_values[0].shape[1] - else: - past_length = past_key_values[0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > 
past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - ) - return model_inputs - - def _get_initial_cache_position(self, seq_length, device, model_kwargs): - """ - Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length. - Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`. - """ - past_length = 0 - if "past_key_values" in model_kwargs: - if self.config.multi_query: - past_length = model_kwargs["past_key_values"][0].shape[1] - else: - past_length = model_kwargs["past_key_values"][0].shape[2] - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - else: - cur_len = seq_length - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device) - return model_kwargs - @auto_docstring def forward( self, @@ -1051,6 +666,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -1087,6 +703,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -1114,17 +731,6 @@ def forward( cross_attentions=transformer_outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) - @auto_docstring( custom_intro=""" @@ -1165,6 +771,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): @@ -1198,6 +805,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 583c1ba48341..371af3e2a8dd 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -460,7 +460,7 @@ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=Non ) if position_ids is None: - if past_key_values is not None: + if past_key_values is not None and past_key_values.get_seq_length() != 0: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX prev_num_input_ids = past_key_values.get_seq_length() @@ -1312,12 +1312,6 @@ def forward( batch_size, sequence_length = inputs_embeds.shape[:2] - main_stream_pos_embed, position_ids = self.position_embeddings( - (batch_size, sequence_length), - device=inputs_embeds.device, - past_key_values=past_key_values, - ) - if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -1337,6 +1331,12 @@ def forward( past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + main_stream_pos_embed, position_ids = self.position_embeddings( + (batch_size, sequence_length), + device=inputs_embeds.device, + past_key_values=past_key_values, + ) + if past_key_values_length != 0: main_relative_position_buckets, predict_relative_position_buckets = None, None else: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index df045a35c797..91de2cd8a6d6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2508,13 +2508,12 @@ def _check_generate_outputs(self, output, config, use_cache=False, num_return_se # Past Key Value States -- a few notes here: # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" # 2. We ignore models that have unique cache structures (e.g. 
mamba) or are in need of refactoring to match the
-        # standard cache format (e.g.gptbigcode )
+        # standard cache format (e.g. mamba architecture)
         models_without_standard_cache = (
             "bamba",
             "ctrl",
             "fsmt",
             "granitemoehybrid",
-            "gptbigcode",
             "mega",
             "reformer",
             "jamba",
@@ -5154,7 +5153,6 @@ def assert_no_sklearn(self):
 
     @parameterized.expand([(is_sklearn_available(),), (False,)])
     def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
-        print("test_update_candidate_strategy_no_matches_short")
         self.original_matches = []
         self.candidate_generator.matches = self.original_matches
         self.num_matches = 0

From 85061bc1510c5f9ff78cf644e99247997ebc2eb9 Mon Sep 17 00:00:00 2001
From: raushan 
Date: Mon, 9 Jun 2025 12:15:25 +0200
Subject: [PATCH 07/58] delete tuple cache reference in generation

---
 .../generation/candidate_generator.py        | 46 +-------------
 src/transformers/generation/utils.py         | 61 +++----------------
 2 files changed, 8 insertions(+), 99 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index bb9222030553..6e8c74ad5a73 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -295,9 +295,7 @@ def _update_past_and_masks(
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
-            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
-            )
+            self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
             self.assistant_kwargs = _prepare_attention_mask(
                 self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
             )
@@ -1179,48 +1177,6 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         base_model.config.num_hidden_layers = original_num_hidden_layers
         return candidate_ids, candidate_logits
-
-def _crop_past_key_values(model, past_key_values, max_length):
-    """Crops the past key values up to a certain maximum length."""
-    new_past = []
-    if isinstance(past_key_values, Cache):
-        past_key_values.crop(max_length)
-    elif model.config.is_encoder_decoder:
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :max_length, :],
-                    past_key_values[idx][1][:, :, :max_length, :],
-                    past_key_values[idx][2],
-                    past_key_values[idx][3],
-                )
-            )
-        past_key_values = tuple(new_past)
-    # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
-    elif "gptbigcode" in model.__class__.__name__.lower() or (
-        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
-    ):
-        if model.config.multi_query:
-            for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :max_length, :]
-        else:
-            for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
-    elif past_key_values is not None:
-        for idx in range(len(past_key_values)):
-            if past_key_values[idx] != ([], []):
-                new_past.append(
-                    (
-                        past_key_values[idx][0][:, :, :max_length, :],
-                        past_key_values[idx][1][:, :, :max_length, :],
-                    )
-                )
-            else:
-                new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
-        past_key_values = tuple(new_past)
-    return past_key_values
-
-
-def 
_prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: """Expands or crops the model's mask for decoding purposes, to the defined length""" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8a0b9103c036..ededb2eb2a75 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -68,7 +68,6 @@ EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, UniversalSpeculativeDecodingGenerator, - _crop_past_key_values, _prepare_attention_mask, _prepare_token_type_ids, ) @@ -567,15 +566,7 @@ def prepare_inputs_for_generation( # 1. Handle BC: model_inputs = {} - # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`) - if self._supports_cache_class: - model_inputs["cache_position"] = cache_position - # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this - # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly - # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) - elif cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + model_inputs["cache_position"] = cache_position # 2. Generic cache-dependent input preparation if past_key_values is not None: @@ -1553,14 +1544,7 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): ) def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - # If a `Cache` instance is passed, checks whether the model is compatible with it - if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: - raise ValueError( - f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " - "check the model documentation for supported cache formats." - ) - + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: @@ -3692,30 +3676,6 @@ def _sample( else: return input_ids - # Auxiliary functions for beam search - def _temporary_reorder_cache(self, past_key_values, beam_idx): - """ - Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. - - TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need - for this function, with `Cache.reorder_cache` being the sole remaining code path - """ - model_class = self.__class__.__name__.lower() - # Exception: models with different cache formats. These are limited to `DynamicCache` until their - # cache format is standardized, to avoid adding complexity to the codebase. - if "gptbigcode" in model_class: - if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): - raise ValueError( - f"Using an unsupported cache format with {model_class}. 
Currently, it only supports the " - "legacy tuple format or `DynamicCache`" - ) - past_key_values = self._reorder_cache(past_key_values, beam_idx) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # Standard code path: use the `Cache.reorder_cache` - else: - past_key_values.reorder_cache(beam_idx) - return past_key_values - @staticmethod def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor: """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]""" @@ -4174,10 +4134,8 @@ def _beam_search( # pluck the cache from the beam indices that will be used in the next iteration if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - past_key_values=model_kwargs["past_key_values"], - beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]), - ) + beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]) + model_kwargs["past_key_values"].reorder_cache(beam_idx) cur_len = cur_len + 1 this_peer_finished = not self._beam_search_has_unfinished_sequences( @@ -4475,9 +4433,7 @@ def _group_beam_search( del outputs if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], reordering_indices - ) + model_kwargs["past_key_values"].reorder_cache(reordering_indices) # increase cur_len cur_len = cur_len + 1 @@ -4712,9 +4668,7 @@ def _constrained_beam_search( del outputs if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) + model_kwargs["past_key_values"].reorder_cache(beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) @@ -4939,8 +4893,7 @@ def _assisted_decoding( new_cur_len = input_ids.shape[1] # 4.2. Discard past key values relative to unused assistant tokens - new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + outputs.past_key_values.crop(new_cur_len - 1) # 5. 
Update the candidate generation strategy if needed candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) From 1424600b41aaf7da607705c6d3b7d932e8eb0428 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 9 Jun 2025 12:45:50 +0200 Subject: [PATCH 08/58] fix some models --- src/transformers/models/mvp/modeling_mvp.py | 12 +- .../models/roberta/modeling_roberta.py | 2 +- src/transformers/models/xlm/modeling_xlm.py | 114 ++++++++---------- tests/generation/test_utils.py | 4 +- 4 files changed, 57 insertions(+), 75 deletions(-) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 736b30084d18..135d844827c3 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1355,17 +1355,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1782,6 +1771,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 298e342b39dc..b5062208efc6 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -341,7 +341,7 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.key(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index df017e27b18b..3ef96b641826 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -500,6 +501,7 @@ def __init__(self, n_heads, dim, config): self.layer_id = next(MultiHeadAttention.NEW_ID) self.dim = dim self.n_heads = n_heads + self.head_dim = dim // n_heads self.dropout = config.attention_dropout assert self.dim % self.n_heads == 0 @@ -524,50 +526,57 @@ def prune_heads(self, heads): self.dim = attention_head_size * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False): + def forward( + self, + input, + mask, + kv=None, + cache=None, + 
head_mask=None, + output_attentions=False, + cache_position=None, + ): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """ # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) bs, qlen, dim = input.size() - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = kv.size(1) - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - n_heads = self.n_heads - dim_per_head = self.dim // n_heads - mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) - - def unshape(x): - """compute context""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + is_cross_attention = kv is not None + mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1) + q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + if isinstance(cache, EncoderDecoderCache): + is_updated = cache.is_updated.get(self.layer_id) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = cache.cross_attention_cache else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + curr_past_key_value = cache.self_attention_cache + else: + curr_past_key_value = cache - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) + current_states = kv if is_cross_attention else input + if is_cross_attention and cache is not None and is_updated: + # reuse k,v, cross_attentions + k = curr_past_key_value.key_cache[self.layer_id] + v = curr_past_key_value.value_cache[self.layer_id] + else: + k = self.k_lin(current_states) + v = self.v_lin(current_states) + k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + + if cache is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + cache.is_updated[self.layer_id] = True + + q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen) @@ -579,8 +588,8 @@ def unshape(x): if head_mask is not None: weights = weights * head_mask - context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - 
context = unshape(context) # (bs, qlen, dim) + context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim) + context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim) outputs = (self.out_lin(context),) if output_attentions: @@ -793,6 +802,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # Dummy kwargs for now ) -> Union[Tuple, BaseModelOutput]: r""" @@ -829,45 +839,38 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + if not isinstance(cache, Cache): + cache = EncoderDecoderCache.from_legacy_cache(cache) + if lengths is None: if input_ids is not None: lengths = (input_ids != self.pad_index).sum(dim=1).long() else: lengths = torch.tensor([slen] * bs, device=device) - # mask = input_ids != self.pad_index # check inputs assert lengths.size(0) == bs assert lengths.max().item() <= slen - # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 - # assert (src_enc is None) == (src_len is None) - # if src_enc is not None: - # assert self.is_decoder - # assert src_enc.size(0) == bs # generate masks mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) - # if self.is_decoder and src_enc is not None: - # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # position_ids if position_ids is None: position_ids = self.position_ids[:, :slen] else: assert position_ids.size() == (bs, slen) # (slen, bs) - # position_ids = position_ids.transpose(0, 1) # langs if langs is not None: assert langs.size() == (bs, slen) # (slen, bs) - # langs = langs.transpose(0, 1) # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.n_layers) # do not recompute cached elements if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] + _slen = slen - cache.get_seq_length() input_ids = input_ids[:, -_slen:] position_ids = position_ids[:, -_slen:] if langs is not None: @@ -902,6 +905,7 @@ def forward( cache=cache, head_mask=head_mask[i], output_attentions=output_attentions, + cache_position=cache_position, ) attn = attn_outputs[0] if output_attentions: @@ -910,13 +914,6 @@ def forward( tensor = tensor + attn tensor = self.layer_norm1[i](tensor) - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - # FFN tensor = tensor + self.ffns[i](tensor) tensor = self.layer_norm2[i](tensor) @@ -926,13 +923,6 @@ def forward( if output_hidden_states: hidden_states = hidden_states + (tensor,) - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 0 - # tensor = tensor.transpose(0, 1) - if not return_dict: return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) @@ -1034,6 +1024,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, MaskedLMOutput]: r""" @@ -1076,6 +1067,7 
@@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 91de2cd8a6d6..aa6768f66a1f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -545,8 +545,8 @@ def test_greedy_generate_dict_outputs_use_cache(self): if self.has_attentions: config._attn_implementation = "eager" # can't output attentions otherwise - # if not hasattr(config.get_text_config(), "use_cache"): - # self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") From 02fb0d2b4ec178b71d5d694c5b728fe1f1764303 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 9 Jun 2025 16:20:48 +0200 Subject: [PATCH 09/58] fix some models --- .../generation/candidate_generator.py | 2 +- src/transformers/generation/utils.py | 2 +- .../models/align/modeling_align.py | 4 +- .../models/altclip/modeling_altclip.py | 12 +- .../models/autoformer/modeling_autoformer.py | 160 +++++++++-------- src/transformers/models/bart/modeling_bart.py | 11 -- src/transformers/models/bert/modeling_bert.py | 31 ++-- .../modeling_bert_generation.py | 27 +-- .../models/big_bird/modeling_big_bird.py | 47 +++-- .../modeling_tf_blenderbot_small.py | 10 -- .../models/blip/modeling_blip_text.py | 21 ++- .../bridgetower/modeling_bridgetower.py | 70 ++++---- src/transformers/models/bros/modeling_bros.py | 146 +++++++++------- .../models/camembert/modeling_camembert.py | 9 +- .../chinese_clip/modeling_chinese_clip.py | 13 +- src/transformers/models/clap/modeling_clap.py | 13 +- .../data2vec/modeling_data2vec_audio.py | 10 -- .../models/data2vec/modeling_data2vec_text.py | 31 ++-- .../models/electra/modeling_electra.py | 13 +- .../models/ernie/modeling_ernie.py | 27 +-- src/transformers/models/esm/modeling_esm.py | 162 +++++++++++------- .../models/layoutlm/modeling_layoutlm.py | 4 +- src/transformers/models/lilt/modeling_lilt.py | 11 +- .../models/markuplm/modeling_markuplm.py | 4 +- .../megatron_bert/modeling_megatron_bert.py | 13 +- .../models/musicgen/modeling_musicgen.py | 142 ++++++++------- .../modeling_musicgen_melody.py | 135 ++++++++------- .../models/rembert/modeling_rembert.py | 31 ++-- .../models/roberta/modeling_roberta.py | 30 ++-- .../modeling_roberta_prelayernorm.py | 27 +-- .../models/roc_bert/modeling_roc_bert.py | 33 ++-- .../models/roformer/modeling_roformer.py | 40 ++--- .../seamless_m4t/modeling_seamless_m4t.py | 2 +- .../modeling_seamless_m4t_v2.py | 12 +- .../models/speecht5/modeling_speecht5.py | 4 +- .../models/splinter/modeling_splinter.py | 13 +- .../models/superglue/modeling_superglue.py | 4 +- .../xlm_roberta/modeling_xlm_roberta.py | 9 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 32 ++-- .../models/xlnet/modeling_xlnet.py | 4 +- src/transformers/models/xmod/modeling_xmod.py | 27 +-- tests/models/mvp/test_modeling_mvp.py | 6 +- .../models/speecht5/test_modeling_speecht5.py | 6 +- 43 files changed, 755 insertions(+), 655 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 6e8c74ad5a73..24524e4b459f 100644 --- 
a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -28,7 +28,6 @@ if is_sklearn_available(): from sklearn.metrics import roc_curve -from ..cache_utils import Cache from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor @@ -1177,6 +1176,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, base_model.config.num_hidden_layers = original_num_hidden_layers return candidate_ids, candidate_logits + def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: """Expands or crops the model's mask for decoding purposes, to the defined length""" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ededb2eb2a75..7235941c76a6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1544,7 +1544,7 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): ) def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 0a9d52b2fa74..c0c37d1fca04 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -658,8 +658,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 5c983f05de87..c14676ef8552 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -249,8 +249,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -1231,7 +1231,13 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + 
else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 0a41692f69cd..90f879313d42 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -458,10 +459,11 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - autocorrelation_factor: int = 3, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + autocorrelation_factor: Optional[int] = 3, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -476,6 +478,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -484,69 +487,63 @@ def __init__( self.autocorrelation_factor = autocorrelation_factor - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), 
-1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) # (1) period-based dependencies discovery # Resize (truncation or zero filling) @@ -745,7 +742,7 @@ def forward( class AutoformerDecoderLayer(nn.Module): - def __init__(self, config: AutoformerConfig): + def __init__(self, config: AutoformerConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -755,6 +752,7 @@ def __init__(self, config: AutoformerConfig): dropout=config.attention_dropout, is_decoder=True, autocorrelation_factor=config.autocorrelation_factor, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -767,6 +765,7 @@ def __init__(self, config: AutoformerConfig): dropout=config.attention_dropout, is_decoder=True, autocorrelation_factor=config.autocorrelation_factor, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -796,9 +795,10 @@ def forward( 
encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -824,15 +824,13 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -841,20 +839,18 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -862,9 +858,6 @@ def forward( # added layer norm here as an improvement hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -886,7 +879,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1092,7 +1085,9 @@ def __init__(self, config: AutoformerConfig): self.embed_positions = AutoformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([AutoformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [AutoformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) # https://github.com/thuml/Autoformer/blob/e6371e24f2ae2dd53e472edefdd5814c5176f864/models/Autoformer.py#L74 @@ -1116,6 +1111,7 @@ def forward( output_attentions: Optional[bool] = 
None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, AutoFormerDecoderOutput]: r""" Args: @@ -1188,6 +1184,22 @@ def forward( input_shape = inputs_embeds.size()[:-1] + if self.gradient_checkpointing and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1206,7 +1218,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1226,14 +1238,7 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, @@ -1245,6 +1250,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1256,15 +1262,16 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) (hidden_states, residual_trend) = layer_outputs[0] trend = trend + residual_trend if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1280,6 +1287,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1491,6 +1501,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[AutoformerModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1672,6 +1683,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) else: decoder_outputs = AutoFormerDecoderOutput() diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2442baa24363..3477bc957da2 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1568,17 +1568,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1c9317d844d0..e0e79d2f9679 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -227,7 +227,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -338,7 +337,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -357,7 +355,6 @@ def forward( 
attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -372,8 +369,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) @@ -493,7 +488,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -503,7 +497,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -570,8 +563,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -593,12 +586,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -968,8 +960,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index c0487490f74f..43838ce0dae6 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -92,7 +92,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -231,7 +230,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -241,7 +239,6 @@ def 
forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -313,8 +310,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -336,12 +333,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -708,8 +704,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 18e43e7c2bb4..f6d45bc7ca89 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -338,6 +338,7 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -356,8 +357,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -1375,12 +1376,12 @@ def forward( if self.attention_type == "original_full": self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) else: @@ -1433,13 +1434,13 @@ def __init__(self, config, seed=None): self.attention_type = config.attention_type self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BigBirdAttention(config, seed=seed, layer_idx=seed) + self.attention = BigBirdAttention(config, seed=seed) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise TypeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BigBirdAttention(config, layer_idx=seed) + self.crossattention = BigBirdAttention(config, seed=seed) self.intermediate = BigBirdIntermediate(config) self.output = BigBirdOutput(config) @@ -1475,8 +1476,8 @@ def forward( # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, @@ -1505,13 +1506,12 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - cache_position, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights @@ -1603,7 +1603,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1617,7 +1616,7 @@ def forward( from_mask, to_mask, blocked_encoder_mask, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -1632,7 +1631,7 @@ def forward( 
from_mask, to_mask, blocked_encoder_mask, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -1950,8 +1949,8 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + # Model doesn't have pure self attention, so no cache is expected + past_key_values_length = 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 4de98280836d..d16c9bd7494a 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -210,16 +210,6 @@ def call( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if self.is_decoder: - # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) key_states = tf.reshape(key_states, proj_shape) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 14d9cd9ae6e7..4f399aeca48f 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -138,6 +138,11 @@ def save_attention_map(self, attention_map): def get_attention_map(self): return self.attention_map + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + def forward( self, hidden_states: torch.Tensor, @@ -155,6 +160,7 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -173,8 +179,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -339,7 +345,7 @@ def __init__(self, config, layer_num): self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BlipTextAttention(config) + self.attention = BlipTextAttention(config, layer_idx=layer_num) self.layer_num = layer_num if self.config.is_decoder: self.crossattention = BlipTextAttention( @@ -752,8 +758,13 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 45661e9f3e1d..3b4d03372828 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -449,7 +449,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -478,8 +477,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -588,7 +587,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -598,7 +596,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -609,14 +606,14 @@ 
def forward( class BridgeTowerBertCrossLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BridgeTowerAttention(config) + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention - self.crossattention = BridgeTowerAttention(config) + self.crossattention = BridgeTowerAttention(config, layer_idx=layer_idx) self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) @@ -629,6 +626,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attention_outputs = self.attention( @@ -646,12 +644,12 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask=attention_mask, + attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] # add cross attentions if we output attention weights @@ -671,17 +669,17 @@ def feed_forward_chunk(self, attention_output): class BridgeTowerTextLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BridgeTowerAttention(config) + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute") + self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) @@ -692,28 +690,27 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - 
cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -721,24 +718,18 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -746,7 +737,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -1114,8 +1105,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1220,10 +1216,10 @@ def __init__(self, config): ln.bias.data = self.vision_model.visual.ln_post.bias.data self.cross_modal_image_layers = nn.ModuleList( - [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)] + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] ) self.cross_modal_text_layers = nn.ModuleList( - [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)] + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] ) # Class token => Linear => Tanh diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index cadd1225427a..16fd32e8beaa 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -192,7 +193,7 @@ def forward( class BrosSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -215,6 +216,7 
@@ def __init__(self, config): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor): new_x_shape = x.size()[:-1] + ( @@ -232,8 +234,9 @@ def forward( head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[torch.Tensor] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -242,36 +245,38 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -345,9 +350,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BrosAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = BrosSelfAttention(config) + self.self = BrosSelfAttention(config, layer_idx=layer_idx) self.output = BrosSelfOutput(config) self.pruned_heads = set() @@ -379,9 +384,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states=hidden_states, @@ -389,9 +394,9 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -429,17 +434,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BrosLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BrosAttention(config) + self.attention = BrosAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise Exception(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BrosAttention(config) + self.crossattention = BrosAttention(config, layer_idx) self.intermediate = BrosIntermediate(config) self.output = BrosOutput(config) @@ -451,53 +456,46 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, bbox_pos_emb=bbox_pos_emb, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is 
tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if hasattr(self, "crossattention"): raise Exception( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, @@ -508,7 +506,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -522,7 +520,7 @@ class BrosEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BrosLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) def forward( self, @@ -537,18 +535,28 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: @@ -565,7 +573,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, + None, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -575,13 +585,14 @@ def forward( head_mask=layer_head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -590,12 +601,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -604,7 +619,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -732,6 +747,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'): @@ -782,8 +798,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) @@ -844,6 +865,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 
5a96c2234771..068479c2aaf5 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -881,8 +881,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7a8948a8045a..0df7b4d91067 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -308,8 +308,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -1114,8 +1114,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 912cce28b5e8..754512ed25c9 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1186,8 +1186,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -1739,8 +1739,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = 
torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index eafcbff89ae5..12a7fdf4ef7f 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -300,16 +300,6 @@ def forward( key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 757b5cd4ef30..9a2a6cc37dd9 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -177,7 +177,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -331,7 +330,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -341,7 +339,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -413,8 +410,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -436,12 +433,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -680,6 +676,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: 
Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -705,8 +702,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -759,6 +761,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -819,6 +822,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -861,6 +865,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 05eb5f9ec850..09caf8ccbc63 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -267,8 +267,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -787,8 +787,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 8c1ef579a8e0..4fd701b4b64e 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -163,7 +163,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - 
encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -317,7 +316,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -327,7 +325,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -397,8 +394,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -420,12 +417,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -804,8 +800,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 7ec40a852871..8a1a142f9a77 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -249,7 +250,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): class EsmSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.config = config @@ -279,6 +280,7 @@ def __init__(self, config, position_embedding_type=None): self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -294,6 +296,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 
output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -302,23 +305,36 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) @@ -328,16 +344,6 @@ def forward( # ESM code and fix rotary embeddings. query_layer = query_layer * self.attention_head_size**-0.5 - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
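The hunk above applies the same per-layer cache handling that this patch repeats across every ported *SelfAttention module. For reference, a stripped-down sketch of that pattern is shown below; the `project_kv_with_cache` helper and its argument names are hypothetical and only illustrative, while `EncoderDecoderCache`, `is_updated`, `key_cache`/`value_cache`, and `update()` are the cache APIs the diff itself relies on.

from transformers.cache_utils import EncoderDecoderCache

def project_kv_with_cache(attn, hidden_states, encoder_hidden_states, past_key_value, cache_position):
    # `attn` is assumed to be a BERT-style self-attention module with .key/.value projections,
    # .transpose_for_scores, and a .layer_idx, as in the modules touched by this patch.
    is_cross_attention = encoder_hidden_states is not None
    curr_cache, is_updated = past_key_value, False
    if isinstance(past_key_value, EncoderDecoderCache):
        is_updated = past_key_value.is_updated.get(attn.layer_idx, False)
        curr_cache = (
            past_key_value.cross_attention_cache if is_cross_attention else past_key_value.self_attention_cache
        )

    if is_cross_attention and past_key_value is not None and is_updated:
        # cross-attention keys/values were written on the first decoding step; just re-read them
        key_layer = curr_cache.key_cache[attn.layer_idx]
        value_layer = curr_cache.value_cache[attn.layer_idx]
    else:
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        key_layer = attn.transpose_for_scores(attn.key(current_states))
        value_layer = attn.transpose_for_scores(attn.value(current_states))
        if past_key_value is not None:
            # cross-attention entries are static, so no cache_position is passed for them
            pos = cache_position if not is_cross_attention else None
            key_layer, value_layer = curr_cache.update(
                key_layer, value_layer, attn.layer_idx, {"cache_position": pos}
            )
            if is_cross_attention and isinstance(past_key_value, EncoderDecoderCache):
                past_key_value.is_updated[attn.layer_idx] = True
    return key_layer, value_layer

Writing the cross-attention entries once and marking them via `is_updated` avoids re-projecting the encoder states on every decoding step, which is the point of the comments in the hunks above.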
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) @@ -408,8 +414,8 @@ class EsmFlashAttention2(EsmSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. @@ -426,6 +432,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: # Flash attention doesn't support output_attentions or cross attention if output_attentions or head_mask is not None or encoder_hidden_states is not None: @@ -450,8 +457,13 @@ def forward( key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) if past_key_value is not None: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if isinstance(past_key_value, EncoderDecoderCache): + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. 
Hence, we need
@@ -520,9 +532,9 @@ def forward(
 class EsmAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
-        self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
         self.output = EsmSelfOutput(config)
         self.pruned_heads = set()
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -554,6 +566,7 @@ def forward(
         encoder_attention_mask=None,
         past_key_value=None,
         output_attentions=False,
+        cache_position=None,
     ):
         hidden_states_ln = self.LayerNorm(hidden_states)
         self_outputs = self.self(
@@ -564,6 +577,7 @@ def forward(
             encoder_attention_mask,
             past_key_value,
             output_attentions,
+            cache_position=cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -595,17 +609,17 @@ def forward(self, hidden_states, input_tensor):
 class EsmLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = EsmAttention(config)
+        self.attention = EsmAttention(config, layer_idx)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = EsmAttention(config)
+            self.crossattention = EsmAttention(config, layer_idx)
         self.intermediate = EsmIntermediate(config)
         self.output = EsmOutput(config)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -619,26 +633,25 @@ def forward(
         encoder_attention_mask=None,
         past_key_value=None,
         output_attentions=False,
+        cache_position=None,
     ):
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
         )
         attention_output = self_attention_outputs[0]
         # if decoder, the last output is tuple of self-attn cache
         if self.is_decoder:
             outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
         else:
             outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-        cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             if not hasattr(self, "crossattention"):
                 raise AttributeError(
@@ -646,31 +659,25 @@ def forward(
                     " with cross-attention layers by setting `config.add_cross_attention=True`"
                 )
-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             cross_attention_outputs = self.crossattention(
                 attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                cross_attn_past_key_value,
-                output_attentions,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                past_key_value=past_key_value,
+
output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = self.feed_forward_chunk(attention_output) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs def feed_forward_chunk(self, attention_output): @@ -684,7 +691,7 @@ class EsmEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([EsmLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -700,25 +707,35 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
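The deprecation warning above tells callers to wrap legacy tuple caches in an `EncoderDecoderCache` themselves. A minimal sketch of that migration, assuming a legacy per-layer tuple cache, is shown below; the two helper functions are hypothetical, while `from_legacy_cache` and `to_legacy_cache` are the conversion methods used elsewhere in this patch.

from transformers.cache_utils import EncoderDecoderCache

def wrap_legacy_cache(legacy_past_key_values):
    # legacy_past_key_values: the deprecated tuple format, one (key, value[, cross_key, cross_value])
    # group of tensors per layer. Convert it once, before calling the model.
    return EncoderDecoderCache.from_legacy_cache(legacy_past_key_values)

def unwrap_cache(cache):
    # Only needed while downstream code still expects the old tuple format.
    return cache.to_legacy_cache()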
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + next_decoder_cache = None all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -728,8 +745,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -738,13 +756,14 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -756,12 +775,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -770,7 +793,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -884,6 +907,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length))`): @@ -927,8 +951,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -977,6 +1006,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index de605c6f498d..0dd38c8af060 
100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -189,8 +189,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index d56a1d9ef85c..ca78a1421d19 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -185,7 +185,7 @@ def forward(self, bbox=None, position_ids=None): class LiltSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -220,6 +220,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.channel_shrink_ratio = config.channel_shrink_ratio + self.layer_idx = layer_idx def transpose_for_scores(self, x, r=1): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size // r) @@ -337,9 +338,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class LiltAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.output = LiltSelfOutput(config) self.pruned_heads = set() @@ -420,11 +421,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class LiltLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = LiltAttention(config) + self.attention = LiltAttention(config, layer_idx=layer_idx) self.intermediate = LiltIntermediate(config) self.output = LiltOutput(config) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index bfc070459e2c..b01e986dc5b2 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -395,8 +395,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast 
auto-regressive generation
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index 620450abfb25..7d5a25ec67a2 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -245,8 +245,8 @@ def forward(
             key_layer = curr_past_key_value.key_cache[self.layer_idx]
             value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            key_layer = self.transpose_for_scores(self.k_proj(current_states))
-            value_layer = self.transpose_for_scores(self.v_proj(current_states))
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
         if past_key_value is not None:
             # save all key/value_layer to cache to be re-used for fast auto-regressive generation
@@ -834,8 +834,13 @@ def forward(
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
-        # past_key_values_length
-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 295be0da2d3d..d539db21f0fa 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -26,6 +26,7 @@
 from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import (
     ClassifierFreeGuidanceLogitsProcessor,
     GenerationConfig,
@@ -194,6 +195,7 @@ def __init__(
         bias: bool = True,
         is_causal: bool = False,
         config: Optional[MusicgenConfig] = None,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -210,6 +212,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
         self.is_causal = is_causal
+        self.layer_idx = layer_idx
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -220,10 +223,11 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
         # TODO: we need a refactor so that the different attention modules can get their specific kwargs
         # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
         **kwargs: Unpack[FlashAttentionKwargs],
@@ -244,42 +248,35 @@ def forward(
         # get query proj
         query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value
is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -305,7 +302,7 @@ def forward( class MusicgenDecoderLayer(nn.Module): - def __init__(self, config: MusicgenDecoderConfig): + def __init__(self, config: MusicgenDecoderConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -317,6 +314,7 @@ def __init__(self, config: MusicgenDecoderConfig): bias=False, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -330,6 +328,7 @@ def __init__(self, config: MusicgenDecoderConfig): is_decoder=True, bias=False, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) @@ -346,9 +345,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -372,42 +372,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, 
key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -423,7 +416,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -474,7 +467,9 @@ def __init__(self, config: MusicgenDecoderConfig): config.hidden_size, ) - self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MusicgenDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation @@ -503,6 +498,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -562,7 +558,23 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: @@ -583,23 +595,14 @@ def forward( # embed positions positions = self.embed_positions(input, past_key_values_length) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -617,8 +620,6 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, @@ -631,6 +632,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -642,14 +644,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -664,6 +667,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -785,6 +791,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -844,6 +851,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -911,6 +919,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -973,6 +982,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 92c3a8372b29..12258e5a3842 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, @@ -210,6 +211,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[MusicgenMelodyConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim 
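
The musicgen decoder hunks above upgrade tuple-style `past_key_values` with `EncoderDecoderCache.from_legacy_cache` and convert back with `to_legacy_cache` before returning. A minimal sketch of that round trip, not part of the diffs above, assuming only the `Cache`/`EncoderDecoderCache` helpers these files import from `...cache_utils`; the single-layer legacy tuple and tensor shapes are toy values:

# Illustrative sketch: legacy tuple <-> EncoderDecoderCache round trip (toy shapes, one layer).
import torch
from transformers.cache_utils import EncoderDecoderCache

k = torch.zeros(1, 2, 3, 4)  # (batch, heads, seq_len, head_dim)
v = torch.zeros(1, 2, 3, 4)
legacy = ((k, v, k, v),)  # legacy per-layer format: (self_k, self_v, cross_k, cross_v)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 3  # length is read from the self-attention cache

# ...decoder layers would call cache.update(...) here during generation...

legacy_again = cache.to_legacy_cache()  # returned when the caller passed a tuple
assert legacy_again[0][0].shape[-2] == 3
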
@@ -226,6 +228,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -236,10 +239,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -260,42 +264,35 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -321,7 +318,7 @@ def forward( class MusicgenMelodyDecoderLayer(nn.Module): - def __init__(self, config: MusicgenMelodyDecoderConfig): + def __init__(self, config: MusicgenMelodyDecoderConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -333,6 +330,7 @@ def __init__(self, config: MusicgenMelodyDecoderConfig): bias=False, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -349,9 +347,10 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -368,15 +367,13 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -396,7 +393,7 @@ def forward( outputs += (self_attn_weights,) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -449,7 +446,9 @@ def __init__(self, config: MusicgenMelodyDecoderConfig): config.hidden_size, ) - self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MusicgenMelodyDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation @@ -478,6 +477,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: 
Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -531,9 +531,24 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) @@ -561,22 +576,13 @@ def forward( # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -594,8 +600,6 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, @@ -605,20 +609,22 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_attentions += (layer_outputs[1],) @@ -630,6 +636,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) return BaseModelOutputWithPast( @@ -724,6 +733,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -776,6 +786,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -843,6 +854,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, MusicgenMelodyOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -897,6 +909,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index a38bbbc767c8..fad1c6f3d52f 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -231,7 +231,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, @@ -260,8 +259,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = 
self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -355,7 +354,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -365,7 +363,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -435,8 +432,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -458,12 +455,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -751,8 +747,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index b5062208efc6..d6c770595090 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -176,7 +176,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -288,7 +287,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -307,7 +305,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -322,7 +319,6 @@ def 
forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -445,7 +441,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -455,7 +450,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -525,8 +519,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -548,12 +542,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -821,8 +814,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 88f23c2c9f6e..8aaf88b4c51a 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -175,7 +175,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -321,7 +320,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -332,7 +330,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -403,8 +400,8 @@ def forward( ) 
-> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -426,12 +423,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -708,8 +704,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 1f15223b4d45..86f23459f048 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -290,7 +290,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -319,8 +318,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -444,7 +443,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -454,7 +452,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -501,7 +498,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RoCBertAttention(config, layer_idx) + self.attention = RoCBertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -524,8 
+521,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -547,12 +544,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -899,8 +895,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 21b8c42d6e21..0f148a3d3b6d 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -222,7 +222,6 @@ def forward( sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -251,8 +250,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -372,7 +371,6 @@ def forward( sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -383,7 +381,6 @@ def forward( sinusoidal_pos, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -453,9 +450,9 @@ def forward( ): self_attention_outputs = self.attention( hidden_states, - attention_mask, - sinusoidal_pos, - head_mask, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -477,13 +474,12 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - sinusoidal_pos, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, 
cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -562,7 +558,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -573,7 +568,7 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -585,7 +580,7 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -887,8 +882,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index e4fdf2e17efe..8e5d5f098a4e 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1091,7 +1091,7 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 8587780aa683..e7f0f4177227 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -987,20 +987,10 @@ def forward( past_key_value.is_updated[self.layer_idx] = True query_states = self.q_proj(hidden_states) - query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states * self.scaling attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - if attention_mask is not None: attention_scores = attention_scores + attention_mask diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 7d522972410f..6ee2578b019d 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -815,7 +815,7 @@ def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, ): if input_ids is not None: input_shape = input_ids.size() @@ -823,7 +823,7 @@ def forward( else: raise ValueError("You have to specify `decoder_input_ids`") - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values[0][0].shape[-2] if past_key_values is not None else 0 positions = self.embed_positions(input_ids, past_key_values_length) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 16776b0a3719..f962446d5eb4 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -161,8 +161,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -638,8 +638,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 651d442fa679..44828c6bdec5 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -301,8 +301,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not 
None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 9a12e56c517b..6b460d023dd5 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -812,8 +812,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index b2fbb4dc03e5..aaa0a803c018 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -174,7 +174,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -286,7 +285,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -305,7 +303,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -320,7 +317,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -339,7 +335,7 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.key(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: @@ -441,7 +437,6 @@ def forward( attention_mask=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -452,7 +447,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -520,8 +514,8 @@ def forward( ): self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, 
cache_position=cache_position, @@ -543,12 +537,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -807,8 +800,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 5a677f1d7161..a496cd6df7b3 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -187,7 +187,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path): class XLNetRelativeAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.d_model % config.n_head != 0: @@ -214,6 +214,7 @@ def __init__(self, config): self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.dropout) + self.layer_idx = layer_idx def prune_heads(self, heads): raise NotImplementedError @@ -322,6 +323,7 @@ def forward( target_mapping=None, head_mask=None, output_attentions=False, + cache_position=None, ): if g is not None: # Two-stream attention with relative positional encoding. 
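
Several of the encoder models above repeat the same `past_key_values_length` branch that accepts either a legacy tuple or a `Cache` instance. A small sketch of that branch written as a standalone helper, not part of the diffs above; `past_length` is an illustrative name, shapes are toy values, and `DynamicCache` stands in for any concrete `Cache` subclass:

# Illustrative sketch of the repeated past_key_values_length computation (toy shapes).
import torch
from transformers.cache_utils import Cache, DynamicCache


def past_length(past_key_values):
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length()
    # legacy tuples store keys as (batch, heads, seq_len, head_dim) per layer
    return past_key_values[0][0].shape[-2]


k = v = torch.zeros(1, 2, 5, 4)
assert past_length(((k, v),)) == 5  # legacy tuple path

cache = DynamicCache()
cache.update(k, v, layer_idx=0)
assert past_length(cache) == 5  # Cache path
assert past_length(None) == 0
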
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 1263b805977e..df5a6487fd96 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -174,7 +174,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -320,7 +319,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -333,7 +331,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -459,8 +456,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -482,12 +479,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -803,8 +799,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if lang_ids is None: if self.config.default_language is None: diff --git a/tests/models/mvp/test_modeling_mvp.py b/tests/models/mvp/test_modeling_mvp.py index 9ee1077d7d9b..1dc16992633c 100644 --- a/tests/models/mvp/test_modeling_mvp.py +++ b/tests/models/mvp/test_modeling_mvp.py @@ -770,9 +770,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 5fc1d6706666..95da4a641d7e 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -354,9 
+354,9 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() From f7494bcfb186a165d7e556924670fe7ec6cfd364 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 9 Jun 2025 16:51:01 +0200 Subject: [PATCH 10/58] fix mambas and support cache in tapas --- src/transformers/generation/utils.py | 4 +- .../seamless_m4t/modeling_seamless_m4t.py | 2 + .../modeling_seamless_m4t_v2.py | 4 +- .../models/tapas/modeling_tapas.py | 101 +++++++++++------- 4 files changed, 66 insertions(+), 45 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7235941c76a6..4ac06aed49d3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1955,9 +1955,9 @@ def _supports_default_dynamic_cache(self) -> bool: order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). """ - return ( + return all( special_model_name not in self.__class__.__name__.lower() - for special_model_name in ["jamba", "zamba", "mamba"] + for special_model_name in ["jamba", "zamba", "mamba", "bamba"] ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 8e5d5f098a4e..83a4a1faf465 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -3132,6 +3132,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -3192,6 +3193,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(decoder_outputs[0]) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index e7f0f4177227..fde845446e16 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -2348,7 +2348,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" @@ -2378,7 +2377,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -3404,6 +3402,7 @@ def forward( output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -3464,6 +3463,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(decoder_outputs[0]) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index d409f6742b74..d8b76269ad68 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -285,7 +285,7 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs class TapasSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -303,6 +303,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -318,6 +319,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): mixed_query_layer = self.query(hidden_states) @@ -326,29 +328,38 @@ def forward( # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) - if self.is_decoder: - past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) @@ -395,9 +406,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class TapasAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = TapasSelfAttention(config) + self.self = TapasSelfAttention(config, layer_idx=layer_idx) self.output = TapasSelfOutput(config) self.pruned_heads = set() @@ -427,19 +438,17 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -479,17 +488,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class TapasLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = TapasAttention(config) + self.attention = TapasAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TapasAttention(config) + self.crossattention = TapasAttention(config, layer_idx=layer_idx) self.intermediate = TapasIntermediate(config) self.output = TapasOutput(config) @@ -507,8 +516,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -530,12 +539,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -563,7 +571,7 @@ class TapasEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([TapasLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -578,7 +586,16 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): + if use_cache and not 
isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): @@ -597,6 +614,7 @@ def forward( encoder_attention_mask, past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -607,6 +625,7 @@ def forward( encoder_attention_mask, past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] if output_attentions: From 576fb7b6c05882b7dc2def23d9b1c5276c7a1f62 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 11:57:58 +0200 Subject: [PATCH 11/58] fix some more tests --- .../models/big_bird/modeling_big_bird.py | 58 +++++++----------- .../chinese_clip/modeling_chinese_clip.py | 2 +- src/transformers/models/clvp/modeling_clvp.py | 59 +++++-------------- src/transformers/models/fsmt/modeling_fsmt.py | 14 ++--- .../models/kosmos2/modeling_kosmos2.py | 2 +- src/transformers/models/rag/modeling_rag.py | 17 ++++++ .../models/speecht5/modeling_speecht5.py | 9 ++- .../models/whisper/generation_whisper.py | 7 +-- .../xlm_roberta/modeling_xlm_roberta.py | 8 +-- tests/models/bert/test_modeling_bert.py | 8 ++- .../models/big_bird/test_modeling_big_bird.py | 2 + .../prophetnet/test_modeling_prophetnet.py | 2 +- 12 files changed, 87 insertions(+), 101 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f6d45bc7ca89..ea5501b9093c 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -334,41 +334,22 @@ def forward( ): mixed_query_layer = self.query(hidden_states) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + # NOTE: BigBird has only cross attention layers so we can ignore self attn path + if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = curr_past_key_value.key_cache[self.layer_idx] - value_layer = curr_past_key_value.value_cache[self.layer_idx] + key_layer = past_key_value.key_cache[self.layer_idx] + value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_layer, value_layer = curr_past_key_value.update( - key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + key_layer, value_layer = past_key_value.update( + key_layer, + value_layer, + self.layer_idx, ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True query_layer = self.transpose_for_scores(mixed_query_layer) @@ -376,9 +357,9 @@ def forward( attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: + if encoder_attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BigBirdModel forward() function) - attention_scores = attention_scores + attention_mask + attention_scores = attention_scores + encoder_attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) @@ -1590,11 +1571,11 @@ def forward( if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
) return_legacy_cache = True - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) next_decoder_cache = None @@ -1949,8 +1930,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # Model doesn't have pure self attention, so no cache is expected past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 0df7b4d91067..ae8ec65da77a 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -590,7 +590,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ChineseCLIPTextAttention(config, layer_idx) + self.attention = ChineseCLIPTextAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index a5377adb2a08..a2fc7ea564a2 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1355,55 +1355,28 @@ def _prepare_model_inputs( return inputs, input_name, model_kwargs def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, conditioning_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + conditioning_embeds=None, + cache_position=None, + **kwargs, ): # Overwritten: has `conditioning_embeds`-related logic input_ids_length = input_ids.shape[-1] - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - if conditioning_embeds is not None and past_key_values is not None: - position_ids = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - 
model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "token_type_ids": token_type_ids, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, ) + if conditioning_embeds is not None and cache_position[0] != 0: + model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device) + return model_inputs @auto_docstring diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 2cba9db9b592..cd8381903db3 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -802,6 +802,7 @@ def forward( else: curr_past_key_value = layer_state + # NOTE: FSMT has format (seq_len, BS, model_dim) ofr inputs current_states = key if self.encoder_decoder_attention else query if self.encoder_decoder_attention and layer_state is not None and is_updated: # reuse k,v, cross_attentions @@ -810,8 +811,8 @@ def forward( else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3) + value_states = value_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3) if layer_state is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -825,11 +826,10 @@ def forward( query_states = self.q_proj(query) * self.scaling - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + # Reshape back to 3D tensors for `bmm` + query_states = query_states.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + key_states = key_states.reshape(bsz * self.num_heads, -1, self.head_dim) + value_states = value_states.reshape(bsz * self.num_heads, -1, self.head_dim) assert key_states is not None src_len = key_states.size(1) diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 86c3d0a65079..de9ad2c6cdf3 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1542,7 +1542,7 @@ def prepare_inputs_for_generation( ) # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - if cache_position[0] == 0: + if cache_position[0] != 0: image_embeds = None image_embeds_position_mask = None diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 2f9c3dbd9e20..809afd8866a8 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1558,6 +1558,15 @@ def extend_enc_output(tensor, num_beams=None): generation_config=generation_config, stopping_criteria=stopping_criteria ) + self._prepare_cache_for_generation( + generation_config, + model_kwargs, + 
assistant_model=None, + batch_size=input_ids.shape[0], + max_cache_length=generation_config.max_length - 1, + device=input_ids.device, + ) + if generation_config.num_beams == 1: if generation_config.num_return_sequences > 1: raise ValueError( @@ -1576,6 +1585,14 @@ def extend_enc_output(tensor, num_beams=None): elif generation_config.num_beams > 1: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) return self._beam_search( input_ids, logits_processor=pre_processor, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 6ee2578b019d..35a61203f88f 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -823,7 +823,14 @@ def forward( else: raise ValueError("You have to specify `decoder_input_ids`") - past_key_values_length = past_key_values[0][0].shape[-2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + positions = self.embed_positions(input_ids, past_key_values_length) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 08552d411453..505118c19822 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1103,7 +1103,7 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for v in [cache_cls.key_cache, cache_cls.value_cache]: layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) - return tuple(all_past_key_values) + return EncoderDecoderCache.from_legacy_cache(tuple(all_past_key_values)) else: all_past_key_values = [] for v in range(len(values)): @@ -1150,7 +1150,6 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs): for i in range(len(seek_outputs[0][key])) ) elif key == "past_key_values": - past_key_value_type = kwargs.get("past_key_values") if seek_outputs[0][key] is not None: outputs[key] = tuple( tuple( @@ -1159,8 +1158,8 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs): ) for i in range(len(seek_outputs[0][key])) ) - if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache): - outputs[key] = past_key_value_type.from_legacy_cache(outputs[key]) + if isinstance(seek_outputs[0][key], EncoderDecoderCache): + outputs[key] = EncoderDecoderCache.from_legacy_cache(outputs[key]) else: outputs[key] = None diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 6b460d023dd5..2131376f2333 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -206,8 +206,8 @@ def forward( key_layer = 
curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -342,8 +342,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index a6aac8e3829a..c75175a811fb 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -737,9 +737,11 @@ def test_sdpa_ignored_mask(self): torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4) ) - # Case where query length != kv_length. - res_eager = model(**inp, past_key_values=pkv) - res_sdpa = model_sdpa(**inp, past_key_values=pkv) + # Case where query length != kv_length. Note that model needs to be a decoder so we can use cache + model.config.is_decoder = True + model_sdpa.config.is_decoder = True + res_eager = model(**inp, past_key_values=pkv, use_cache=True) + res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True) self.assertTrue( torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4) ) diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index bdab0f73b653..8ec874d0f7a8 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -284,6 +284,7 @@ def create_and_check_decoder_model_past_large_inputs( attention_mask=next_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + use_cache=False, output_hidden_states=True, )["hidden_states"][0] output_from_past = model( @@ -292,6 +293,7 @@ def create_and_check_decoder_model_past_large_inputs( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, + use_cache=True, output_hidden_states=True, )["hidden_states"][0] diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py index b9632b21bbe5..42c7a4bb5d2e 100644 --- a/tests/models/prophetnet/test_modeling_prophetnet.py +++ b/tests/models/prophetnet/test_modeling_prophetnet.py @@ -737,7 +737,7 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), 
output_from_past.shape[-1]).item() From 8757e846bff3a59af49e1df03cada3434149aa18 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 13:03:33 +0200 Subject: [PATCH 12/58] fix copies --- .../models/align/modeling_align.py | 20 ++-- .../models/altclip/modeling_altclip.py | 21 ++-- src/transformers/models/bart/modeling_bart.py | 9 -- .../models/bart/modeling_tf_bart.py | 2 +- src/transformers/models/bert/modeling_bert.py | 8 -- .../models/big_bird/modeling_big_bird.py | 5 +- .../modeling_bigbird_pegasus.py | 64 ++++------- .../blenderbot/modeling_tf_blenderbot.py | 2 +- .../modeling_tf_blenderbot_small.py | 12 ++- .../bridgetower/modeling_bridgetower.py | 2 +- .../models/camembert/modeling_camembert.py | 31 +++--- .../chinese_clip/modeling_chinese_clip.py | 18 ++-- src/transformers/models/clap/modeling_clap.py | 22 ++-- .../data2vec/modeling_data2vec_audio.py | 34 +----- .../models/electra/modeling_electra.py | 20 ++-- .../models/flaubert/modeling_flaubert.py | 102 +++++++++--------- .../models/hubert/modeling_hubert.py | 44 +------- .../models/layoutlm/modeling_layoutlm.py | 20 ++-- .../models/marian/modeling_tf_marian.py | 2 +- .../models/markuplm/modeling_markuplm.py | 20 ++-- .../models/mbart/modeling_tf_mbart.py | 2 +- .../megatron_bert/modeling_megatron_bert.py | 96 +++++++++-------- .../models/musicgen/modeling_musicgen.py | 1 - .../modeling_musicgen_melody.py | 2 +- .../models/nllb_moe/modeling_nllb_moe.py | 2 +- .../patchtsmixer/modeling_patchtsmixer.py | 44 +------- .../models/patchtst/modeling_patchtst.py | 44 +------- .../models/pegasus/modeling_tf_pegasus.py | 2 +- .../models/prophetnet/modeling_prophetnet.py | 4 +- src/transformers/models/sew/modeling_sew.py | 44 +------- .../speech_to_text/modeling_speech_to_text.py | 2 +- .../models/splinter/modeling_splinter.py | 20 ++-- .../models/superglue/modeling_superglue.py | 3 - .../models/tapas/modeling_tapas.py | 10 +- .../models/unispeech/modeling_unispeech.py | 44 +------- .../unispeech_sat/modeling_unispeech_sat.py | 44 +------- .../models/wav2vec2/modeling_wav2vec2.py | 44 +------- .../xlm_roberta/modeling_xlm_roberta.py | 23 ++-- 38 files changed, 279 insertions(+), 610 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index c0c37d1fca04..8e6a4ef50463 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -629,7 +629,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -783,7 +782,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -793,7 +791,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -840,7 +837,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim 
= 1 - self.attention = AlignTextAttention(config, layer_idx) + self.attention = AlignTextAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -863,8 +860,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -886,12 +883,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index c14676ef8552..e7b389d4819c 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -220,7 +220,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -374,7 +373,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -384,7 +382,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -431,7 +428,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = AltRobertaAttention(config, layer_idx) + self.attention = AltRobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -454,8 +451,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -477,12 +474,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -1230,7 +1226,6 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids 
is not None else inputs_embeds.device - # past_key_values_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = ( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 3477bc957da2..c5269516eaf6 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1995,15 +1995,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BartForCausalLM", diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index a772f0fd2346..7ab9817986e6 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1539,7 +1539,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index e0e79d2f9679..02917376ec72 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1268,14 +1268,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BertForMaskedLM(BertPreTrainedModel): diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index ea5501b9093c..46ffa4f76e2b 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -335,13 +335,14 @@ def forward( mixed_query_layer = self.query(hidden_states) # NOTE: BigBird has only cross attention layers so we can ignore self attn path + current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index d49d4e65bd70..7e943aaecb32 100755 --- 
a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -107,7 +107,7 @@ def forward(self, input_ids: torch.Tensor): # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus class BigBirdPegasusSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -125,6 +125,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -140,51 +141,37 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): mixed_query_layer = self.query(hidden_states) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: + # NOTE: BigBirdPegasus has only cross attention layers so we can ignore self attn path + current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = past_key_value.key_cache[self.layer_idx] + value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + key_layer, value_layer = past_key_value.update( + key_layer, + value_layer, + self.layer_idx, + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: + if encoder_attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BigBirdPegasusModel forward() function) - attention_scores = attention_scores + attention_mask + attention_scores = attention_scores + encoder_attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) @@ -2686,17 +2673,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index f3476cb925b6..23f817a03770 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1525,7 +1525,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index d16c9bd7494a..4c04c912b66d 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -210,6 +210,16 @@ def call( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) key_states = tf.reshape(key_states, proj_shape) @@ -1485,7 +1495,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 3b4d03372828..993c94cc96cf 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -464,7 +464,7 @@ def forward( if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache + # after the first generated id, we can subsequently re-use all key/value_layer from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 068479c2aaf5..0843f1fdbeac 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -177,7 +177,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -206,8 +205,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -289,7 +288,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -308,7 +306,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -323,7 +320,6 @@ def forward( is_cross_attention = encoder_hidden_states 
is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -342,8 +338,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - value_layer = self.transpose_for_scores(self.k_proj(current_states)) - value_layer = self.transpose_for_scores(self.v_proj(current_states)) + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -446,7 +442,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -456,7 +451,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -503,7 +497,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = CamembertAttention(config, layer_idx) + self.attention = CamembertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -526,8 +520,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -549,12 +543,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index ae8ec65da77a..02651d26741c 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -279,7 +279,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -433,7 +432,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, 
output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -443,7 +441,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -615,8 +612,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -638,12 +635,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 754512ed25c9..56b66a418b0d 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1157,7 +1157,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -1311,7 +1310,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -1321,7 +1319,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -1368,7 +1365,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ClapTextAttention(config, layer_idx) + self.attention = ClapTextAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -1391,8 +1388,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -1414,12 +1411,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] @@ -1714,6 +1710,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: 
Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1798,6 +1795,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 12a7fdf4ef7f..84fa79ef54f2 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -249,10 +249,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -273,32 +272,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -320,7 +296,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class Data2VecAudioFeedForward(nn.Module): diff --git a/src/transformers/models/electra/modeling_electra.py 
b/src/transformers/models/electra/modeling_electra.py index 09caf8ccbc63..b208a0644e50 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -238,7 +238,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -392,7 +391,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -402,7 +400,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -449,7 +446,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ElectraAttention(config, layer_idx) + self.attention = ElectraAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -472,8 +469,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -495,12 +492,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 8eb2fc4b1c2a..0b6a6147280c 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -87,6 +88,7 @@ def __init__(self, n_heads, dim, config): self.layer_id = next(MultiHeadAttention.NEW_ID) self.dim = dim self.n_heads = n_heads + self.head_dim = dim // n_heads self.dropout = config.attention_dropout assert self.dim % self.n_heads == 0 @@ -111,50 +113,57 @@ def prune_heads(self, heads): self.dim = attention_head_size * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False): + def forward( + self, + input, + mask, + kv=None, + cache=None, + head_mask=None, + output_attentions=False, + cache_position=None, + ): """ 
Self-attention (if kv is None) or attention over source sentence (provided by kv). """ # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) bs, qlen, dim = input.size() - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = kv.size(1) - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - n_heads = self.n_heads - dim_per_head = self.dim // n_heads - mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) - - def unshape(x): - """compute context""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + is_cross_attention = kv is not None + mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1) + q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + if isinstance(cache, EncoderDecoderCache): + is_updated = cache.is_updated.get(self.layer_id) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = cache.cross_attention_cache else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + curr_past_key_value = cache.self_attention_cache + else: + curr_past_key_value = cache - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) + current_states = kv if is_cross_attention else input + if is_cross_attention and cache is not None and is_updated: + # reuse k,v, cross_attentions + k = curr_past_key_value.key_cache[self.layer_id] + v = curr_past_key_value.value_cache[self.layer_id] + else: + k = self.k_lin(current_states) + v = self.v_lin(current_states) + k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + + if cache is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + cache.is_updated[self.layer_id] = True + + q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen) @@ -166,8 +175,8 @@ def unshape(x): if head_mask is not None: weights = weights * head_mask - context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) + context = 
torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim) + context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim) outputs = (self.out_lin(context),) if output_attentions: @@ -813,6 +822,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutput]: r""" lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -847,6 +857,9 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + if not isinstance(cache, Cache): + cache = EncoderDecoderCache.from_legacy_cache(cache) + if lengths is None: if input_ids is not None: lengths = (input_ids != self.pad_index).sum(dim=1).long() @@ -892,7 +905,7 @@ def forward( # do not recompute cached elements if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] + _slen = slen - cache.get_seq_length() input_ids = input_ids[:, -_slen:] position_ids = position_ids[:, -_slen:] if langs is not None: @@ -934,6 +947,7 @@ def forward( cache=cache, head_mask=head_mask[i], output_attentions=output_attentions, + cache_position=cache_position, ) attn = attn_outputs[0] if output_attentions: @@ -950,13 +964,6 @@ def forward( attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) tensor = tensor + attn - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - # FFN if not self.pre_norm: tensor = tensor + self.ffns[i](tensor) @@ -971,13 +978,6 @@ def forward( if output_hidden_states: hidden_states = hidden_states + (tensor,) - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 0 - # tensor = tensor.transpose(0, 1) - if not return_dict: return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 115345407e69..1a89ff92567d 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -309,10 +309,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -333,42 +332,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - 
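For readers unfamiliar with the cache objects that the FlauBERT rewrite above (and the BERT-family changes elsewhere in this patch) migrate to, here is a rough usage sketch of the pattern, assuming the public DynamicCache/EncoderDecoderCache API from transformers.cache_utils; batch, head, and length dimensions are illustrative only.

import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0
k = torch.randn(2, 8, 1, 64)   # (batch, heads, new_tokens, head_dim)
v = torch.randn(2, 8, 1, 64)

# self-attention path: append this step's key/value states under the layer's slot
k_all, v_all = cache.self_attention_cache.update(k, v, layer_idx)

# cross-attention path: computed once from the encoder states, then re-used
if not cache.is_updated.get(layer_idx):
    enc_k = torch.randn(2, 8, 15, 64)
    enc_v = torch.randn(2, 8, 15, 64)
    cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
    cache.is_updated[layer_idx] = True
enc_k = cache.cross_attention_cache.key_cache[layer_idx]
enc_v = cache.cross_attention_cache.value_cache[layer_idx]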
value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -390,7 +356,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class HubertFeedForward(nn.Module): diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 0dd38c8af060..e8563b766666 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -160,7 +160,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -314,7 +313,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -324,7 +322,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -371,7 +368,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = LayoutLMAttention(config, layer_idx) + self.attention =
LayoutLMAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -394,8 +391,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -417,12 +414,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 9884b6d7e9e2..0e483ac146c2 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1523,7 +1523,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index b01e986dc5b2..086764ed297f 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -366,7 +366,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -505,7 +504,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -515,7 +513,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -531,7 +528,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = MarkupLMAttention(config, layer_idx) + self.attention = MarkupLMAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -554,8 +551,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, 
output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -577,12 +574,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 16c53caa3f23..878452fc41ee 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1539,7 +1539,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 7d5a25ec67a2..2bab74a57937 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -216,7 +216,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -329,10 +328,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm. 
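The Marian/MBart prepare_inputs_for_generation hunks above stop peeling the past length out of the first layer's key tensor and use the cache's length accessor instead. A rough equivalence sketch with the PyTorch DynamicCache (not part of the patch; tensor sizes are made up):

import torch
from transformers import DynamicCache

legacy = ((torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64)),)  # one layer: (key, value)
cache = DynamicCache.from_legacy_cache(legacy)

# the old way of recovering how many positions are already cached ...
past_length_legacy = legacy[0][0].shape[2]
# ... and the accessor this series standardizes on
past_length_cache = cache.get_seq_length()
assert past_length_legacy == past_length_cache == 5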
class MegatronBertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self = MegatronBertSelfAttention(config) + self.self = MegatronBertSelfAttention(config, layer_idx=layer_idx) self.output = MegatronBertSelfOutput(config) self.pruned_heads = set() @@ -360,19 +359,19 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: ln_outputs = self.ln(hidden_states) self_outputs = self.self( ln_outputs, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -410,17 +409,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm. class MegatronBertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = MegatronBertAttention(config) + self.attention = MegatronBertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise TypeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = MegatronBertAttention(config) + self.crossattention = MegatronBertAttention(config, layer_idx=layer_idx) self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.intermediate = MegatronBertIntermediate(config) self.output = MegatronBertOutput(config) @@ -432,28 +431,27 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = 
self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise AttributeError( @@ -461,24 +459,18 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) @@ -486,7 +478,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (past_key_value,) return outputs @@ -501,7 +493,7 @@ class MegatronBertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([MegatronBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([MegatronBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one # is simply the final LN (Transformer's BERT has it attached to each hidden layer). @@ -520,6 +512,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -527,17 +520,26 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -547,8 +549,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) else: layer_outputs = layer_module( @@ -557,8 +560,9 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) # Because we moved the layer-norm at the end of the hidden layer, we have non-normali- @@ -566,7 +570,7 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[-1],) + next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -578,12 +582,16 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + next_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -592,7 +600,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -809,6 +817,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -887,6 +896,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1056,6 +1066,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -1096,6 +1107,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index d539db21f0fa..95895139ee51 
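The MegatronBertEncoder block above keeps backward compatibility by converting legacy tuples on the way in and, when a tuple was passed, converting back on the way out. A minimal sketch of that round trip, assuming the EncoderDecoderCache helpers named in the hunk; shapes are illustrative.

import torch
from transformers import EncoderDecoderCache

# legacy format: one tuple per layer with (self_k, self_v, cross_k, cross_v)
legacy = (
    (torch.randn(1, 8, 3, 64), torch.randn(1, 8, 3, 64),
     torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64)),
)
cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.self_attention_cache.get_seq_length() == 3
round_tripped = cache.to_legacy_cache()
assert round_tripped[0][0].shape == legacy[0][0].shape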
100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -182,7 +182,6 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 12258e5a3842..c0b7f47b97d5 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -198,7 +198,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 6a67973863d4..788f38fa47b5 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -503,7 +503,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 8f00e8900928..cbee0d58a25d 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -307,10 +307,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -331,42 +330,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = 
self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -388,7 +354,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class PatchMixerBlock(nn.Module): diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index b85e8a66b254..484f1bb8d210 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -104,10 +104,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -128,42 +127,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not
None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -185,7 +151,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class PatchTSTBatchNorm(nn.Module): diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 15176c92b01d..5d734db0a2ce 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1538,7 +1538,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values[0][0].shape[2] + decoder_position_ids = past_key_values.get_seq_length() else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 371af3e2a8dd..a9549d7c9921 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1922,10 +1922,8 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 330cd99a7b4e..1d4513b6d633 100644 ---
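The ProphetNet hunk above now reorders every cached tensor of a layer during beam search, cross-attention states included. A small stand-alone sketch of what that index_select does (not part of the patch; beam count and tensor sizes are made up):

import torch

beam_idx = torch.tensor([1, 1, 0])                                # which beams survived this step
layer_past = tuple(torch.randn(3, 8, 4, 64) for _ in range(4))    # self k/v + cross k/v for one layer

reordered = tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
assert reordered[0].shape == layer_past[0].shape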
a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -302,10 +302,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -326,42 +325,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -383,7 +349,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class SEWFeedForward(nn.Module): diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 3c86e4b5fa9b..37aabc51b57d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -205,7 +205,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f962446d5eb4..812750a98f47 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -132,7 +132,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -286,7 +285,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -296,7 +294,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -343,7 +340,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = SplinterAttention(config, layer_idx) + self.attention = SplinterAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -366,8 +363,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@
-389,12 +386,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 44828c6bdec5..2e9dbdf49a00 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -272,7 +272,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -421,7 +420,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -431,7 +429,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index d8b76269ad68..2bc23d414fbb 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -444,11 +444,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, - output_attentions=output_attentions, + attention_mask, + head_mask, + encoder_hidden_states, + past_key_value, + output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4fdce328e9e6..c59504dba40f 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -348,10 +348,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -372,42 +371,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the 
`past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -429,7 +395,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class UniSpeechFeedForward(nn.Module): diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 50ee4c198d25..eb4442c3d3dc 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -351,10 +351,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -375,42 +374,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided
`key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -432,7 +398,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class UniSpeechSatFeedForward(nn.Module): diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ae3510f175e0..a808c588783d 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -540,10 +540,9 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -564,42 +563,9 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and
past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -621,7 +587,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class Wav2Vec2FeedForward(nn.Module): diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 2131376f2333..9814909d1c04 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -177,7 +177,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -289,7 +288,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -308,7 +306,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -323,7 +320,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -446,7 +442,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -456,7 +451,6 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, cache_position=cache_position, @@ -503,7 +497,7 @@ def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaAttention(config, layer_idx) + self.attention = XLMRobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: @@ -526,8 +520,8 @@ def forward( ) -> Tuple[torch.Tensor]: self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, past_key_value=past_key_value, cache_position=cache_position, @@ -549,12 +543,11 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = cross_attention_outputs[0] From bcf0cc7c43e7c0fd041c8498874b54e92cf44e82 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 13:04:44 +0200 Subject: [PATCH 13/58] delete `_reorder_cache` --- src/transformers/models/bark/modeling_bark.py | 15 ----- .../modeling_bert_generation.py | 8 --- .../models/big_bird/modeling_big_bird.py | 9 --- .../modeling_bigbird_pegasus.py | 9 --- .../models/biogpt/modeling_biogpt.py | 9 --- .../models/biogpt/modular_biogpt.py | 9 --- .../models/blenderbot/modeling_blenderbot.py | 20 ------- .../modeling_blenderbot_small.py | 20 ------- .../models/blip/modeling_blip_text.py | 8 --- .../models/bloom/modeling_bloom.py | 23 ------- .../models/camembert/modeling_camembert.py | 8 --- src/transformers/models/clvp/modeling_clvp.py | 14 ----- .../models/codegen/modeling_codegen.py | 14 ----- .../models/cpmant/modeling_cpmant.py | 7 --- src/transformers/models/ctrl/modeling_ctrl.py | 14 ----- .../models/data2vec/modeling_data2vec_text.py | 8 --- .../models/electra/modeling_electra.py | 9 --- .../modeling_encoder_decoder.py | 4 -- .../models/ernie/modeling_ernie.py | 9 --- .../models/falcon/modeling_falcon.py | 24 -------- src/transformers/models/fuyu/modeling_fuyu.py | 9 --- src/transformers/models/git/modeling_git.py | 8 --- src/transformers/models/gpt2/modeling_gpt2.py | 14 ----- .../models/gpt_neo/modeling_gpt_neo.py | 14 ----- .../modeling_gpt_neox_japanese.py | 9 --- src/transformers/models/gptj/modeling_gptj.py | 14 ----- .../models/granitemoe/modeling_granitemoe.py | 9 --- .../modeling_granitemoehybrid.py | 9 --- 
.../modeling_granitemoeshared.py | 9 --- .../models/idefics/modeling_idefics.py | 7 --- .../models/idefics2/modeling_idefics2.py | 10 ---- .../models/imagegpt/modeling_imagegpt.py | 14 ----- .../models/kosmos2/modeling_kosmos2.py | 10 ---- src/transformers/models/led/modeling_led.py | 11 ---- .../models/longt5/modeling_longt5.py | 24 -------- .../models/m2m_100/modeling_m2m_100.py | 9 --- .../models/marian/modeling_marian.py | 20 ------- .../models/markuplm/modeling_markuplm.py | 9 --- .../models/mbart/modeling_mbart.py | 20 ------- .../megatron_bert/modeling_megatron_bert.py | 8 --- .../models/moshi/modeling_moshi.py | 14 ----- src/transformers/models/mpt/modeling_mpt.py | 23 ------- src/transformers/models/mt5/modeling_mt5.py | 31 ---------- src/transformers/models/mvp/modeling_mvp.py | 9 --- .../models/nllb_moe/modeling_nllb_moe.py | 21 ++----- src/transformers/models/opt/modeling_opt.py | 9 --- .../models/pegasus/modeling_pegasus.py | 20 ------- .../models/pegasus_x/modeling_pegasus_x.py | 11 ---- .../models/pix2struct/modeling_pix2struct.py | 31 ---------- .../models/plbart/modeling_plbart.py | 20 ------- .../models/plbart/modular_plbart.py | 11 ---- .../models/pop2piano/modeling_pop2piano.py | 30 ---------- .../models/prophetnet/modeling_prophetnet.py | 20 ------- .../models/rembert/modeling_rembert.py | 9 --- .../models/roberta/modeling_roberta.py | 8 --- .../modeling_roberta_prelayernorm.py | 8 --- .../models/roc_bert/modeling_roc_bert.py | 9 --- .../models/roformer/modeling_roformer.py | 9 --- .../seamless_m4t/modeling_seamless_m4t.py | 60 ------------------- .../modeling_seamless_m4t_v2.py | 54 ----------------- .../modeling_speech_encoder_decoder.py | 4 -- .../speech_to_text/modeling_speech_to_text.py | 9 --- .../models/speecht5/modeling_speecht5.py | 9 --- .../modeling_switch_transformers.py | 32 ---------- src/transformers/models/t5/modeling_t5.py | 30 ---------- .../models/trocr/modeling_trocr.py | 9 --- src/transformers/models/udop/modeling_udop.py | 31 ---------- src/transformers/models/umt5/modeling_umt5.py | 9 --- .../modeling_vision_encoder_decoder.py | 4 -- .../models/whisper/modeling_whisper.py | 9 --- src/transformers/models/xglm/modeling_xglm.py | 9 --- .../xlm_roberta/modeling_xlm_roberta.py | 8 --- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 8 --- src/transformers/models/xmod/modeling_xmod.py | 9 --- 74 files changed, 5 insertions(+), 1060 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index cc342b647201..5b2dcffa05ac 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -659,21 +659,6 @@ def forward( attentions=all_self_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - # Necessary for beam_search - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 43838ce0dae6..62c18ce70118 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -904,14 +904,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BertGenerationDecoder", diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 46ffa4f76e2b..6f5b8b8892ac 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2493,15 +2493,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class BigBirdClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 7e943aaecb32..df30b5f7fa68 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -3070,15 +3070,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BigBirdPegasusForCausalLM", diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index f12eeac69730..2261684e3e2f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -795,15 +795,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BioGptForTokenClassification(BioGptPreTrainedModel): diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 78d6da134b80..a498ed15ee0e 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -622,15 +622,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - 
for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BioGptForTokenClassification(BioGptPreTrainedModel): diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 4c001a354463..97df16a0acfd 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1507,17 +1507,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel): @@ -1658,15 +1647,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BlenderbotForCausalLM", diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 49cff8f620ef..7115b9d3eb10 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1459,17 +1459,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel): @@ -1610,15 +1599,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BlenderbotSmallForCausalLM", diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 4f399aeca48f..beb19d7366de 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -984,13 +984,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti "is_decoder": True, } - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, 
beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["BlipTextModel", "BlipTextLMHeadModel", "BlipTextPreTrainedModel"] diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index ce67dd208356..f93ee1e60d6d 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -956,29 +956,6 @@ def forward( attentions=transformer_outputs.attentions, ) - def _reorder_cache( - self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 0843f1fdbeac..9c81e591c8c2 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -1581,14 +1581,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index a2fc7ea564a2..b596abeb56ad 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1453,20 +1453,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index a1bef381ce70..8fca177b7584 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -716,19 +716,5 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - __all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"] diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 4f11bb49654b..e6ac65decc43 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -810,12 +810,5 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def _reorder_cache(self, past_key_values, beam_idx): - past_key_values = [list(each) if each is not None else each for each in past_key_values] - for key_value_layer in past_key_values: - key_value_layer[0] = key_value_layer[0][beam_idx] - key_value_layer[1] = key_value_layer[1][beam_idx] - return past_key_values - __all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"] diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 78bf0f8c02af..f381add072e4 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -588,20 +588,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 9a2a6cc37dd9..93c06646f334 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -893,14 +893,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel): diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index b208a0644e50..dbf2c5a815bf 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1612,15 +1612,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "ElectraForCausalLM", diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 8aaced73980f..d5df1faa4fea 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -607,9 +607,5 @@ def resize_token_embeddings(self, *args, **kwargs): " model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past_key_values, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past_key_values, beam_idx) - __all__ = ["EncoderDecoderModel"] diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 4fd701b4b64e..b9d830dd57d2 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1099,15 +1099,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class ErnieForMaskedLM(ErniePreTrainedModel): diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 2f54b8fcee27..ee013faf9b4d 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1131,30 +1131,6 @@ def forward( attentions=transformer_outputs.attentions, ) - def _reorder_cache( - self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> 
Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index c5e998114fb0..24ca04a57916 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -384,14 +384,5 @@ def prepare_inputs_for_generation( return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"] diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 7e1399f9b479..2282b928075f 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1508,13 +1508,5 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"] diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 3e66ede9f1dc..d4f6732db6a7 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1444,20 +1444,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3937d106e42c..f37b503e8e54 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -918,20 +918,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index f4a073cc4a7f..9d15ef357a67 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -773,15 +773,6 @@ def forward( attentions=outputs.attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - __all__ = [ "GPTNeoXJapaneseForCausalLM", diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 093daaef193f..5434c75ade07 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -1056,20 +1056,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index a3a314a6abb5..e6833677a776 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1036,14 +1036,5 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index d6ff36bf3244..c7fac08e57c1 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1737,15 +1737,6 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - def prepare_inputs_for_generation( self, input_ids, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index dc429aa55bce..272618d1fe3c 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -1064,14 +1064,5 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"] diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 9aabc686795c..6b6fc0df056b 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1672,12 +1672,5 @@ def _update_model_kwargs_for_generation( model_kwargs["image_hidden_states"] = outputs.image_hidden_states return model_kwargs - @staticmethod - def _reorder_cache(past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - __all__ = ["IdeficsForVisionText2Text", "IdeficsModel", "IdeficsPreTrainedModel"] diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index f7008ad33e83..1132276ef1ac 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1351,15 +1351,5 @@ def 
_update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ model_kwargs["image_hidden_states"] = outputs.image_hidden_states return model_kwargs - @staticmethod - # Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"] diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index aa8201489531..63b5698e95d5 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -972,20 +972,6 @@ def forward( cross_attentions=transformer_outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index de9ad2c6cdf3..2ca5642e5850 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1572,16 +1572,6 @@ def prepare_inputs_for_generation( return model_inputs - @staticmethod - # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - class Kosmos2ImageToTextProjection(nn.Module): """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 0399b797409a..3bd43cc0e6d6 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2384,17 +2384,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 62e696e8111d..1e64e1b85ada 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ 
b/src/transformers/models/longt5/modeling_longt5.py @@ -2150,30 +2150,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class LongT5EncoderModel(LongT5PreTrainedModel): diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f3488672721a..7e8dcd4f47ad 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1487,14 +1487,5 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["M2M100ForConditionalGeneration", "M2M100Model", "M2M100PreTrainedModel"] diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a604820f2cce..9d613f6ccccc 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1614,17 +1614,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian class MarianDecoderWrapper(MarianPreTrainedModel): @@ -1765,14 +1754,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"] diff --git a/src/transformers/models/markuplm/modeling_markuplm.py 
b/src/transformers/models/markuplm/modeling_markuplm.py index 086764ed297f..53315d9e422b 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -882,15 +882,6 @@ def forward( cross_attentions=encoder_outputs.cross_attentions, ) - # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 4f3253eeb442..25a73249853a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1531,17 +1531,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1968,15 +1957,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "MBartForCausalLM", diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 2bab74a57937..1af60ffc9b20 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -1135,14 +1135,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class MegatronBertForMaskedLM(MegatronBertPreTrainedModel): diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 7e71eb2ce200..fdaa5246064b 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -2594,19 +2594,5 @@ def _check_and_maybe_initialize_inputs( return input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - __all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"] diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index fe2d5d142299..8c560e4b4151 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -541,29 +541,6 @@ def forward( attentions=transformer_outputs.attentions, ) - def _reorder_cache( - self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 6b488b66d22e..36c4b9f3ba7f 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1894,37 +1894,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class MT5EncoderModel(MT5PreTrainedModel): diff --git a/src/transformers/models/mvp/modeling_mvp.py 
b/src/transformers/models/mvp/modeling_mvp.py index 135d844827c3..4d00617ff6c8 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1842,15 +1842,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "MvpForCausalLM", diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 788f38fa47b5..87efa778154b 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -516,7 +516,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[NllbMoeConfig] = None, - layer_idx: Optional[int] = None, + layer_idx: Optional[bool] = None, ): super().__init__() self.embed_dim = embed_dim @@ -547,7 +547,7 @@ def forward( past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn @@ -573,7 +573,7 @@ def forward( if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache + # after the first generated id, we can subsequently re-use all key/value_layer from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache @@ -586,10 +586,8 @@ def forward( key_states = curr_past_key_value.key_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(*kv_input_shape).transpose(1, 2) - value_states = value_states.view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -1807,15 +1805,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None return total_router_logits, total_expert_indexes - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "NllbMoeForConditionalGeneration", diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index eef54b02ec03..7413c92630c7 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -910,15 +910,6 @@ def forward( attentions=outputs.attentions, ) - 
@staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 303ae89fd02b..9ade40688847 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1562,17 +1562,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus class PegasusDecoderWrapper(PegasusPreTrainedModel): @@ -1735,14 +1724,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index bf94379ccae7..549c27cbab16 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1750,17 +1750,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PegasusX class PegasusXDecoderWrapper(PegasusXPreTrainedModel): diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index f9a5b00218df..413cc77e4461 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1050,37 +1050,6 @@ def __init__(self, config): self.post_init() self.gradient_checkpointing = False - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to 
speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - def get_input_embeddings(self): return self.embed_tokens diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 695a0ed458f7..c5e9dc95731a 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1465,17 +1465,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class PLBartClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -1794,15 +1783,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "PLBartForCausalLM", diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 1394e87f5607..4fc41a03ac8d 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -596,17 +596,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class PLBartClassificationHead(BartClassificationHead): pass diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c63b4df774b7..1d309e48d802 100644 --- 
a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1373,35 +1373,5 @@ def generate( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - __all__ = ["Pop2PianoForConditionalGeneration", "Pop2PianoPreTrainedModel"] diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index a9549d7c9921..4a531a04d3d8 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1917,16 +1917,6 @@ def _compute_loss(self, logits, labels, ignore_index=-100): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - @staticmethod - # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - def get_encoder(self): return self.prophetnet.encoder @@ -2149,16 +2139,6 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } - @staticmethod - # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): """ diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index fad1c6f3d52f..2f5c36276570 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -1033,15 +1033,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - 
reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index d6c770595090..164ecc8a5813 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1046,14 +1046,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class RobertaForMaskedLM(RobertaPreTrainedModel): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 8aaf88b4c51a..771028d40be9 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -907,14 +907,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 86f23459f048..eb8044c912d7 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -1511,15 +1511,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, } - # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 0f148a3d3b6d..d9dc46a1f6fd 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -1158,15 +1158,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class RoFormerClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 83a4a1faf465..e364d3a5a961 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -2153,16 +2153,6 @@ 
def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - def _tie_weights(self) -> None: if getattr(self.config, "tie_word_embeddings", True): output_embeddings = self.get_output_embeddings() @@ -2772,16 +2762,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3049,16 +3029,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3388,16 +3358,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3735,16 +3695,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -4185,16 +4135,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - __all__ = [ "SeamlessM4TForTextToSpeech", diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index fde845446e16..aae0eac16a87 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -3022,16 +3022,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be 
reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3309,17 +3299,6 @@ def generate( **kwargs, ) - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3689,17 +3668,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -4076,17 +4044,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -4571,17 +4528,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - __all__ = [ "SeamlessM4Tv2ForTextToSpeech", diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index db131a2cbfe9..aa284895474e 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -507,9 +507,5 @@ def resize_token_embeddings(self, *args, **kwargs): " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past_key_values, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past_key_values, beam_idx) - __all__ = ["SpeechEncoderDecoderModel"] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 37aabc51b57d..ce1af294cf7f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ 
b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1398,14 +1398,5 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["Speech2TextForConditionalGeneration", "Speech2TextModel", "Speech2TextPreTrainedModel"] diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 35a61203f88f..1efc0e718ed2 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2334,15 +2334,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - def _generate_speech( model: SpeechT5PreTrainedModel, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d2db7781d2fc..d6dba63ee428 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1761,38 +1761,6 @@ def _unpack_router_logits(self, router_outputs): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - "expected reordered_layer_past_states to have the same shape than layer_past_states, " - f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - "expected layer_past_states to have the same length as reordered_layer_past_states, " - f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 466b725bce2a..b6dcfa9548b8 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1866,36 +1866,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: 
torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class T5EncoderModel(T5PreTrainedModel): diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index ebcafbcd3fde..2f8165d057c7 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -888,14 +888,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"] diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 478baddc505c..10dd7f9188ab 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1900,37 +1900,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] 
shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class UdopEncoderModel(UdopPreTrainedModel): diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index e45d63aba7db..f1c4c22a77d4 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1419,15 +1419,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class UMT5EncoderModel(UMT5PreTrainedModel): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index c82f4b63113e..95310c3ced53 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -597,9 +597,5 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - def _reorder_cache(self, past_key_values, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past_key_values, beam_idx) - __all__ = ["VisionEncoderDecoderModel"] diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7bb07a6c1c6a..97c679dbb4a6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1567,15 +1567,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 8efed99abe1f..75d604d98aca 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -745,14 +745,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel"] diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 9814909d1c04..b117be20954c 100644 --- 
a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -1038,14 +1038,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index aaa0a803c018..1c49cfaad65d 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -1057,14 +1057,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti "past_key_values": past_key_values, } - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel): diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index df5a6487fd96..dc104f6e80bc 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -1007,15 +1007,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class XmodForMaskedLM(XmodPreTrainedModel): From 91d92f1db8a668d47368448a5cfdfa311a25f75d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 13:11:37 +0200 Subject: [PATCH 14/58] another fix copies --- src/transformers/models/altclip/modeling_altclip.py | 2 ++ .../models/blenderbot/modeling_tf_blenderbot.py | 2 +- .../blenderbot_small/modeling_tf_blenderbot_small.py | 2 +- .../models/bridgetower/modeling_bridgetower.py | 2 ++ src/transformers/models/marian/modeling_tf_marian.py | 2 +- src/transformers/models/mbart/modeling_tf_mbart.py | 2 +- src/transformers/models/musicgen/modeling_musicgen.py | 10 +++++----- .../models/musicgen_melody/modeling_musicgen_melody.py | 8 ++++---- src/transformers/models/pegasus/modeling_tf_pegasus.py | 2 +- 9 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index e7b389d4819c..07f435868c64 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1201,6 +1201,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1285,6 +1286,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 23f817a03770..f3476cb925b6 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1525,7 +1525,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 4c04c912b66d..4de98280836d 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1495,7 +1495,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 993c94cc96cf..b5ecae200250 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1080,6 +1080,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1164,6 +1165,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 0e483ac146c2..9884b6d7e9e2 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1523,7 +1523,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - 
decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 878452fc41ee..16c53caa3f23 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1539,7 +1539,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 95895139ee51..88fab56918aa 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -189,12 +189,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[MusicgenConfig] = None, - layer_idx: Optional[bool] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index c0b7f47b97d5..f4a3a7f167d6 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -206,10 +206,10 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[MusicgenMelodyConfig] = None, layer_idx: Optional[int] = None, ): diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 5d734db0a2ce..15176c92b01d 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1538,7 +1538,7 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: # xla decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] elif past_key_values is not None: # no xla + past_key_values - decoder_position_ids = past_key_values.get_seq_length() + decoder_position_ids = past_key_values[0][0].shape[2] else: # no xla + no past_key_values decoder_position_ids = tf.range(decoder_input_ids.shape[1]) From edf5f6e01a20fb5f42444a02cc5fe7a44264f2fb Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 13:37:54 +0200 Subject: [PATCH 15/58] fix typos and delete unnecessary test --- .../data2vec/modeling_data2vec_audio.py | 4 +-- .../models/hubert/modeling_hubert.py | 4 +-- .../models/nllb_moe/modeling_nllb_moe.py | 10 +++---- 
.../patchtsmixer/modeling_patchtsmixer.py | 4 +-- .../models/patchtst/modeling_patchtst.py | 4 +-- src/transformers/models/sew/modeling_sew.py | 4 +-- .../speech_to_text/modeling_speech_to_text.py | 12 ++++---- .../models/unispeech/modeling_unispeech.py | 4 +-- .../unispeech_sat/modeling_unispeech_sat.py | 4 +-- .../models/wav2vec2/modeling_wav2vec2.py | 4 +-- tests/utils/test_cache_utils.py | 30 ------------------- 11 files changed, 26 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 84fa79ef54f2..8960f9cb4829 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -273,8 +273,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a89ff92567d..b6e792d6b4f6 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -333,8 +333,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 87efa778154b..49fb6b50df28 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -511,12 +511,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[NllbMoeConfig] = None, - layer_idx: Optional[bool] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index cbee0d58a25d..343b6bb8d343 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -331,8 +331,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = 
self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 484f1bb8d210..12daaaa7f9a1 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -128,8 +128,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 1d4513b6d633..9b0868fbfa29 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -326,8 +326,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ce1af294cf7f..50bd95f55e8c 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -249,12 +249,12 @@ def forward( past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -275,7 +275,7 @@ def forward( if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache + # after the 
first generated id, we can subsequently re-use all key/value_layer from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache @@ -288,10 +288,8 @@ def forward( key_states = curr_past_key_value.key_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(*kv_input_shape).transpose(1, 2) - value_states = value_states.view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c59504dba40f..941d97ff4f8d 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -372,8 +372,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index eb4442c3d3dc..41fa667aa650 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -375,8 +375,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a808c588783d..514fdf7adc09 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -564,8 +564,8 @@ def forward( query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) current_states = key_value_states if is_cross_attention else hidden_states - key_states = self.key(current_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.value(current_states).view(*kv_input_shape).transpose(1, 2) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = 
eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 9d435cb7ed11..832cebe3e8d6 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -121,36 +121,6 @@ def test_dynamic_cache_retrocompatibility(self): torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx]) ) - def test_reorder_cache_retrocompatibility(self): - """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" - legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function - - legacy_cache = () - new_cache = DynamicCache() - - # Creates a new cache with 10 layers in both formats - for layer_idx in range(10): - new_key = torch.rand((4, 4, 8, 16)) - new_value = torch.rand((4, 4, 8, 16)) - new_cache.update(new_key, new_value, layer_idx) - legacy_cache += ((new_key, new_value),) - - # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4 - # and batch_size=1 - beam_idx = torch.randint(low=0, high=4, size=(4,)) - - legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx) - new_cache.reorder_cache(beam_idx) - - # Let's check that the results are the same - for layer_idx in range(10): - for key_value_idx in range(2): - self.assertTrue( - torch.allclose( - new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx] - ) - ) - def test_static_cache_mha_mqa_gqa(self): """ Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query From b236e90aabd56248f6a112327cd6de34ee24966d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 13:49:07 +0200 Subject: [PATCH 16/58] fix rag generate, needs special cache reordering --- src/transformers/cache_utils.py | 13 +++++++++++ src/transformers/generation/utils.py | 26 ++++++++++++++------- src/transformers/models/rag/modeling_rag.py | 4 +++- tests/utils/test_cache_utils.py | 1 - 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index be6b18d2162f..1bc567efdebe 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1421,6 +1421,19 @@ def __len__(self): """ return len(self.self_attention_cache) + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor]]: """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" legacy_cache = () diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4ac06aed49d3..5fb3f6d67463 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1000,12 +1000,6 @@ def _update_model_kwargs_for_generation( model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) return model_kwargs - def _reorder_cache(self, past_key_values, beam_idx): - raise NotImplementedError( - f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" - f" enable beam search for {self.__class__}" - ) - def _get_candidate_generator( self, generation_config: GenerationConfig, @@ -4133,9 +4127,13 @@ def _beam_search( # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`) # pluck the cache from the beam indices that will be used in the next iteration + # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]) - model_kwargs["past_key_values"].reorder_cache(beam_idx) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + else: + model_kwargs["past_key_values"].reorder_cache(beam_idx) cur_len = cur_len + 1 this_peer_finished = not self._beam_search_has_unfinished_sequences( @@ -4432,8 +4430,14 @@ def _group_beam_search( # (that way the memory peak does not include outputs.logits) del outputs + # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"].reorder_cache(reordering_indices) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache( + model_kwargs["past_key_values"], reordering_indices + ) + else: + model_kwargs["past_key_values"].reorder_cache(reordering_indices) # increase cur_len cur_len = cur_len + 1 @@ -4667,8 +4671,12 @@ def _constrained_beam_search( # (that way the memory peak does not include outputs.logits) del outputs + # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. 
if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"].reorder_cache(beam_idx) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + else: + model_kwargs["past_key_values"].reorder_cache(beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 809afd8866a8..fb4811a4c3d5 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -21,7 +21,7 @@ import torch from torch import nn -from ...cache_utils import Cache +from ...cache_utils import Cache, EncoderDecoderCache from ...configuration_utils import PretrainedConfig from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput @@ -1199,6 +1199,8 @@ def _reorder_stacked(hidden_states, new_order): tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past), ) + if isinstance(past_key_values, EncoderDecoderCache): + reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past) return reordered_past def marginalize(self, seq_logits, doc_scores, n_docs=None): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 832cebe3e8d6..538713ad9389 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -45,7 +45,6 @@ AutoModelForCausalLM, AutoTokenizer, Cache, - ClvpForCausalLM, DynamicCache, Gemma2Config, GenerationConfig, From 1893f8a82105a2c6acc77af66cfe91cfc1f7c341 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 14:04:38 +0200 Subject: [PATCH 17/58] fix tapas and superglue --- .../models/align/modeling_align.py | 10 +-- src/transformers/models/bert/modeling_bert.py | 10 +-- .../modeling_bert_generation.py | 10 +-- .../bridgetower/modeling_bridgetower.py | 10 +-- .../chinese_clip/modeling_chinese_clip.py | 10 +-- src/transformers/models/clap/modeling_clap.py | 10 +-- .../models/data2vec/modeling_data2vec_text.py | 10 +-- .../models/electra/modeling_electra.py | 10 +-- .../models/ernie/modeling_ernie.py | 10 +-- .../models/layoutlm/modeling_layoutlm.py | 10 +-- .../models/markuplm/modeling_markuplm.py | 10 +-- .../models/rembert/modeling_rembert.py | 10 +-- .../models/roberta/modeling_roberta.py | 10 +-- .../models/roc_bert/modeling_roc_bert.py | 10 +-- .../models/splinter/modeling_splinter.py | 10 +-- .../models/superglue/modeling_superglue.py | 70 ++++--------------- .../models/tapas/modeling_tapas.py | 11 ++- 17 files changed, 95 insertions(+), 136 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 8e6a4ef50463..098685a6e1bb 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -788,11 +788,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) 
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 02917376ec72..24e12296bb64 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -494,11 +494,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 62c18ce70118..11d29f399cf3 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -236,11 +236,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index b5ecae200250..a4c2b1e4b9c7 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -593,11 +593,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 02651d26741c..bd304914530f 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -438,11 +438,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 56b66a418b0d..05861957a3b4 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1316,11 +1316,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + 
attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 93c06646f334..708043448f23 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -336,11 +336,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dbf2c5a815bf..13704b06c020 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -397,11 +397,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b9d830dd57d2..726f7067e698 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -322,11 +322,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index e8563b766666..20a20d132ab3 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -319,11 +319,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 53315d9e422b..e86258469ddb 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -510,11 
+510,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 2f5c36276570..116373910cf9 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -360,11 +360,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 164ecc8a5813..e3861441c281 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -447,11 +447,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index eb8044c912d7..467f9372176d 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -449,11 +449,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 812750a98f47..7f679ccf7b3b 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -291,11 +291,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/superglue/modeling_superglue.py 
b/src/transformers/models/superglue/modeling_superglue.py index 2e9dbdf49a00..7a79603b3834 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -23,7 +23,6 @@ from transformers import PreTrainedModel from transformers.models.superglue.configuration_superglue import SuperGlueConfig -from ...cache_utils import Cache, EncoderDecoderCache from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging from ..auto import AutoModelForKeypointDetection @@ -232,9 +231,8 @@ def forward( return hidden_state, all_hidden_states -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue class SuperGlueSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -259,7 +257,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder - self.layer_idx = layer_idx def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -272,60 +269,27 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, - cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
         is_cross_attention = encoder_hidden_states is not None
-
-        if past_key_value is not None:
-            if isinstance(past_key_value, EncoderDecoderCache):
-                is_updated = past_key_value.is_updated.get(self.layer_idx)
-                if is_cross_attention:
-                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
-                    curr_past_key_value = past_key_value.cross_attention_cache
-                else:
-                    curr_past_key_value = past_key_value.self_attention_cache
-            else:
-                curr_past_key_value = past_key_value
-
         current_states = encoder_hidden_states if is_cross_attention else hidden_states
-        if is_cross_attention and past_key_value is not None and is_updated:
-            # reuse k,v, cross_attentions
-            key_layer = curr_past_key_value.key_cache[self.layer_idx]
-            value_layer = curr_past_key_value.value_cache[self.layer_idx]
-        else:
-            key_layer = self.transpose_for_scores(self.key(current_states))
-            value_layer = self.transpose_for_scores(self.value(current_states))
-
-            if past_key_value is not None:
-                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
-                cache_position = cache_position if not is_cross_attention else None
-                key_layer, value_layer = curr_past_key_value.update(
-                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
-                )
-                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
-                if is_cross_attention:
-                    past_key_value.is_updated[self.layer_idx] = True
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(self.key(current_states))
+        value_layer = self.transpose_for_scores(self.value(current_states))
+        query_layer = self.transpose_for_scores(self.query(hidden_states))

         # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r @@ -365,7 +329,7 @@ def forward( outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (None,) return outputs @@ -384,14 +348,12 @@ def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: } -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE class SuperGlueAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, position_embedding_type=None): super().__init__() self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( config, position_embedding_type=position_embedding_type, - layer_idx=layer_idx, ) self.output = SuperGlueSelfOutput(config) self.pruned_heads = set() @@ -420,18 +382,16 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, - cache_position=cache_position, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 2bc23d414fbb..507deccfc2f5 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -316,7 +316,6 @@ def forward( attention_mask=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -444,11 +443,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) From 46e50b5a213cfc0fd7d0b65bbe7dbf3110ed6eb2 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 
16:22:10 +0200 Subject: [PATCH 18/58] reformer create special cache --- src/transformers/generation/utils.py | 2 +- .../models/reformer/modeling_reformer.py | 226 ++++++++++++++---- .../models/superglue/modeling_superglue.py | 1 - 3 files changed, 174 insertions(+), 55 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5fb3f6d67463..a09f27040063 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1951,7 +1951,7 @@ def _supports_default_dynamic_cache(self) -> bool: """ return all( special_model_name not in self.__class__.__name__.lower() - for special_model_name in ["jamba", "zamba", "mamba", "bamba"] + for special_model_name in ["jamba", "zamba", "mamba", "bamba", "reformer"] ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 3422fbad2b25..3abad3258ffc 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from functools import reduce from operator import mul -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -60,6 +61,117 @@ ) +class ReformerDynamicCache(DynamicCache): + """ + A dynamic cache that stores past buckets instead of key/values. + """ + + def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: + super().__init__() + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.buckets_cache: List[torch.Tensor] = [] + self.states_cache: List[torch.Tensor] = [] + + if _distributed_cache_data is not None: + for buckets, states in _distributed_cache_data: + self.buckets_cache.append(buckets) + self.states_cache.append(states) + + def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.states_cache) + + def update( + self, + buckets: torch.Tensor, + states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 
+
+        Parameters:
+            buckets (`torch.Tensor`):
+                The new buckets to cache; may be `None` for local attention layers.
+            states (`torch.Tensor`):
+                The new hidden states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `ReformerDynamicCache`.
+
+        Return:
+            A tuple containing the updated buckets and hidden states.
+        """
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += states.shape[-2]
+
+        # Update the cache
+        if states is not None:
+            if len(self.states_cache) <= layer_idx:
+                self.states_cache.append(states)
+            else:
+                self.states_cache[layer_idx] = torch.cat([self.states_cache[layer_idx], states], dim=1)
+
+        if buckets is not None:
+            if len(self.buckets_cache) <= layer_idx:
+                self.buckets_cache.append(buckets)
+            else:
+                self.buckets_cache[layer_idx] = torch.cat([self.buckets_cache[layer_idx], buckets], dim=-1)
+        else:
+            # `ReformerLocalAttn` passes `None` to buckets as the module uses no buckets
+            self.buckets_cache.append(torch.tensor([]))
+
+        return self.buckets_cache[layer_idx], self.states_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> Optional[int]:
+        return None
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """Converts the `ReformerDynamicCache` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.buckets_cache[layer_idx], self.states_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_buckets_states: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None
+    ) -> "ReformerDynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `ReformerDynamicCache`. Used for
+        backward compatibility."""
+        cache = cls()
+        if past_buckets_states is not None:
+            for layer_idx in range(len(past_buckets_states)):
+                buckets, states = past_buckets_states[layer_idx]
+                cache.update(buckets, states, layer_idx)
+        return cache
+
+
 def _stable_argsort(vector, dim):
     # this function scales the vector so that torch.argsort is stable.
# torch.argsort is not stable on its own @@ -316,7 +428,7 @@ def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn class LSHSelfAttention(nn.Module, EfficientAttentionMixin): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config @@ -328,6 +440,7 @@ def __init__(self, config): self.hash_seed = config.hash_seed self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings + self.layer_idx = layer_idx self.dropout = config.lsh_attention_probs_dropout_prob @@ -356,6 +469,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, + cache_position=None, **kwargs, ): sequence_length = hidden_states.shape[1] @@ -364,16 +478,13 @@ def forward( # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes - do_cached_attention = use_cache and past_buckets_states[1] is not None - # check if cache shall be used and that hidden states are already cached - if do_cached_attention: + exists_cache = past_buckets_states is not None and len(past_buckets_states) > self.layer_idx + if exists_cache: assert sequence_length == 1, ( "At the moment, auto-regressive language generation is only possible one word at a time. Make sure" f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed." ) - past_buckets = past_buckets_states[0] - past_states = past_buckets_states[1] # get query vector query_vectors = self.query_key(hidden_states) @@ -381,6 +492,9 @@ def forward( query_vectors, self.num_attention_heads, self.attention_head_size ) + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + if past_buckets is not None: key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( query_vectors=query_vectors, @@ -425,7 +539,7 @@ def forward( value_vectors = self.value(hidden_states) # if query key is not already split - if not do_cached_attention or past_buckets is None: + if not use_cache or not exists_cache: query_key_vectors = self._split_hidden_size_dim( query_key_vectors, self.num_attention_heads, self.attention_head_size ) @@ -434,7 +548,7 @@ def forward( ) # cache buckets for next incremental decoding - if do_cached_attention and past_buckets is None and key_value_hidden_states.shape[1] >= self.chunk_length: + if exists_cache and key_value_hidden_states.shape[1] >= self.chunk_length: buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) # free memory @@ -448,7 +562,7 @@ def forward( ) do_standard_self_attention = (sequence_length <= self.chunk_length) or ( - use_cache and past_buckets_states[1] is not None + exists_cache and past_states is not None ) # LSH attention only makes sense if chunked attention should be performed if not do_standard_self_attention: @@ -498,7 +612,7 @@ def forward( "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" " `config.num_chunks_before` are set to 0." 
) - elif do_cached_attention and past_buckets is not None: + elif exists_cache and past_buckets is not None: # use max sequence length sorted_bucket_idx_per_hash = sorted_bucket_idx else: @@ -526,7 +640,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, do_standard_self_attention=do_standard_self_attention, - do_cached_attention=do_cached_attention, + use_cache=use_cache, ) # free memory @@ -537,7 +651,7 @@ def forward( # sort clusters back to correct ordering out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) - if not do_standard_self_attention or (do_cached_attention and past_buckets is not None): + if not do_standard_self_attention or (exists_cache and past_buckets is not None): # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( @@ -721,7 +835,7 @@ def _attend( attention_mask, head_mask, do_standard_self_attention, - do_cached_attention, + use_cache, ): # look at previous and following chunks if chunked attention if not do_standard_self_attention: @@ -741,12 +855,12 @@ def _attend( sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads ) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) - elif do_cached_attention and query_key_dots.ndim > 4: + elif use_cache and query_key_dots.ndim > 4: key_value_bucket_idx = sorted_bucket_idx_per_hash query_bucket_idx = ( key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max() ) - elif do_cached_attention and query_key_dots.ndim <= 4: + elif use_cache and query_key_dots.ndim <= 4: query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1] key_value_bucket_idx = torch.arange( query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device @@ -762,7 +876,7 @@ def _attend( self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - if not do_cached_attention: + if not use_cache: mask = self._compute_attn_mask( query_bucket_idx, key_value_bucket_idx, @@ -1016,7 +1130,7 @@ def backward(ctx, grad_out_vectors, grad_logits): class LocalSelfAttention(nn.Module, EfficientAttentionMixin): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.num_attention_heads = config.num_attention_heads @@ -1029,6 +1143,7 @@ def __init__(self, config): self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size + self.layer_idx = layer_idx # projection matrices self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -1055,13 +1170,16 @@ def forward( batch_size = hidden_states.shape[0] # check if cache shall be used and that hidden states are already cached - if use_cache and past_buckets_states[1] is not None: - assert past_buckets_states[0] is None, ( + if past_buckets_states is not None and len(past_buckets_states) > self.layer_idx: + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + + assert past_buckets.numel() == 0, ( "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching" " hidden_states_and_buckets." 
) key_value_hidden_states = self._retrieve_relevant_hidden_states( - past_buckets_states[1], self.chunk_length, self.num_chunks_before + past_states, self.chunk_length, self.num_chunks_before ) key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1) @@ -1262,15 +1380,15 @@ def __init__(self, config, layer_id=0): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": - self.self_attention = LSHSelfAttention(config) + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": - self.self_attention = LocalSelfAttention(config) + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}: # get correct attn layers if self.attn_layers[self.layer_id] == "lsh": - self.self_attention = LSHSelfAttention(config) + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) else: - self.self_attention = LocalSelfAttention(config) + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) else: raise NotImplementedError( f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. " @@ -1289,52 +1407,40 @@ def forward( orig_sequence_length=None, output_attentions=False, buckets=None, + cache_position=None, ): hidden_states = self.layer_norm(hidden_states) - # make sure cached hidden states is set to None for backward pass - if past_buckets_states is not None: - past_buckets_states_layer = past_buckets_states[self.layer_id] - else: - past_buckets_states_layer = None - # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, - past_buckets_states=past_buckets_states_layer, + past_buckets_states=past_buckets_states, use_cache=use_cache, output_attentions=output_attentions, buckets=buckets, + cache_position=cache_position, ) # add buckets if necessary if hasattr(self_attention_outputs, "buckets"): buckets = self_attention_outputs.buckets + buckets = buckets[:, :, :, :orig_sequence_length] if orig_sequence_length > 1 else buckets else: buckets = None # cache hidden states for future use - if use_cache: - if past_buckets_states[self.layer_id][0] is None: - # padded input should not be cached - past_buckets = ( - buckets[:, :, :, :orig_sequence_length] - if (buckets is not None and orig_sequence_length > 1) - else buckets - ) - else: - past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1) - - if past_buckets_states[self.layer_id][1] is None: - # padded input should not be cached - past_states = hidden_states[:, :orig_sequence_length] - else: - past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1) + if use_cache and past_buckets_states is not None: + # padded input should not be cached during prefill + states = ( + hidden_states[:, :orig_sequence_length] if len(past_buckets_states) <= self.layer_id else hidden_states + ) + buckets, hidden_states = past_buckets_states.update( + buckets, states[:, :orig_sequence_length], self.layer_id + ) - past_buckets_states[self.layer_id] = (past_buckets, past_states) # compute attention feed forward output attention_output = self.output(self_attention_outputs.hidden_states) @@ -1708,8 +1814,15 @@ def 
forward( all_attentions = [] # init cached hidden states if necessary - if past_buckets_states is None: - past_buckets_states = [((None), (None)) for i in range(len(self.layers))] + return_legacy_cache = False + if use_cache or not isinstance(past_buckets_states, ReformerDynamicCache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `ReformerDynamicCache` instead, e.g. " + "`past_key_values=ReformerDynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_buckets_states = ReformerDynamicCache.from_legacy_cache(past_buckets_states) # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) @@ -1734,11 +1847,15 @@ def forward( # Apply dropout hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + next_cache = past_buckets_states if use_cache else None + if return_legacy_cache: + next_cache = past_buckets_states.to_legacy_cache() + return ReformerEncoderOutput( hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions, - past_buckets_states=past_buckets_states, + past_buckets_states=next_cache, ) @@ -2250,7 +2367,7 @@ def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] for layer_past in past_key_values: # buckets - if layer_past[0] is not None: + if layer_past[0].numel() != 0: reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)) else: reord_buckets = None @@ -2258,6 +2375,9 @@ def _reorder_cache(self, past_key_values, beam_idx): # hidden states reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)) reord_past_buckets_states.append((reord_buckets, reord_hidden_states)) + + if isinstance(past_key_values, ReformerDynamicCache): + reord_past_buckets_states = ReformerDynamicCache.from_legacy_cache(reord_past_buckets_states) return reord_past_buckets_states diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 7a79603b3834..19cccbe48fe8 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -272,7 +272,6 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
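
Not part of the series: a minimal usage sketch of the `ReformerDynamicCache` introduced in the patch above, assuming the series is applied so the class is importable from `modeling_reformer`. The toy tensor shapes are arbitrary and only illustrate that buckets grow along the last dimension while hidden states grow along the sequence dimension.

import torch

# assumes this patch series is applied; the class does not exist upstream yet
from transformers.models.reformer.modeling_reformer import ReformerDynamicCache

# Legacy format: one (buckets, states) tuple per layer.
legacy = ((torch.randint(0, 8, (1, 2, 1, 4)), torch.randn(1, 4, 16)),)
cache = ReformerDynamicCache.from_legacy_cache(legacy)

# One decoding step for layer 0: append the new buckets / hidden states.
new_buckets = torch.randint(0, 8, (1, 2, 1, 1))
new_states = torch.randn(1, 1, 16)
buckets, states = cache.update(new_buckets, new_states, layer_idx=0)

print(buckets.shape)  # torch.Size([1, 2, 1, 5]) -- buckets are concatenated on dim=-1
print(states.shape)   # torch.Size([1, 5, 16])   -- hidden states are concatenated on dim=1
print(len(cache.to_legacy_cache()))  # 1 -- back to the tuple-per-layer legacy format
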
From 204ed55d82900111322677126eea0a5a9fe547c1 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 16:22:40 +0200 Subject: [PATCH 19/58] recurrent gemma `reorder_cache` was a no-op, delete --- .../models/recurrent_gemma/modeling_recurrent_gemma.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 315edd29864a..b573421c581f 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -815,15 +815,5 @@ def forward( hidden_states=outputs.hidden_states, ) - # Ignore copy - def _reorder_cache(self, past_key_values, beam_idx): - for layer in self.layers: - if hasattr(layer.temporal_block, "key_states"): - k_state = layer.temporal_block.key_states - v_state = layer.temporal_block.value_states - k_state = k_state.index_select(0, beam_idx.to(k_state.device)) - v_state = v_state.index_select(0, beam_idx.to(v_state.device)) - return None - __all__ = ["RecurrentGemmaForCausalLM", "RecurrentGemmaModel", "RecurrentGemmaPreTrainedModel"] From 7b61dfda9c526a2d03b64154b599cb61c0ba83f0 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 10 Jun 2025 16:54:05 +0200 Subject: [PATCH 20/58] fix-copies --- src/transformers/models/altclip/modeling_altclip.py | 10 +++++----- .../models/camembert/modeling_camembert.py | 10 +++++----- .../models/xlm_roberta/modeling_xlm_roberta.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 07f435868c64..8eceaf7d278b 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -379,11 +379,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 9c81e591c8c2..57d517ffefae 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -448,11 +448,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index b117be20954c..0475d438e128 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -448,11 +448,11 @@ def forward( ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - 
encoder_hidden_states,
-            past_key_value,
-            output_attentions,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
             cache_position=cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)

From 69c20ae2a206a1d75c282823b9b4e9810367d641 Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 10 Jun 2025 17:34:24 +0200
Subject: [PATCH 21/58] fix blip and musicgen pipeline tests

---
 .../models/blip/modeling_blip_text.py | 81 ++++++++-----------
 .../models/musicgen/modeling_musicgen.py | 16 ++--
 2 files changed, 41 insertions(+), 56 deletions(-)

diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index beb19d7366de..81f809841fa9 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -23,7 +23,7 @@ from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -288,20 +288,18 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
-            cache_position,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -367,8 +365,8 @@ def forward(
     ) -> Tuple[torch.Tensor]:
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
             past_key_value=past_key_value,
             cache_position=cache_position,
         )
@@ -379,10 +377,9 @@ def forward(
         if encoder_hidden_states is not None:
             cross_attention_outputs = self.crossattention(
                 attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 cache_position=cache_position,
             )
@@ -434,14 +431,19 @@ def forward(
             use_cache = False
 
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            logger.warning_once(
-                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
-                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
-                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
- ) - return_legacy_cache = True - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + if use_cache: + if not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + # The model acts as encoder decoder but is not an encoder decoder. So we cast all cache objects to + # `EncoderDecoderCache` type assuming that the incoming cache is from `self_attention` + elif isinstance(past_key_values, DynamicCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -957,32 +959,15 @@ def forward( def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): # Overwrite -- hardcoded key return (`is_decoder=True`) - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values.get_seq_length() + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + **model_kwargs, + ) + model_inputs["is_decoder"] = True - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } + return model_inputs __all__ = ["BlipTextModel", "BlipTextLMHeadModel", "BlipTextPreTrainedModel"] diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 88fab56918aa..af7dc25af89c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1897,6 +1897,7 @@ def prepare_inputs_for_generation( encoder_outputs=None, decoder_delay_pattern_mask=None, guidance_scale=None, + cache_position=None, **kwargs, ): # Overwritten -- MusicGen has custom processing @@ -1918,16 +1919,15 @@ def prepare_inputs_for_generation( decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if cache_position[-1] >= decoder_input_ids.shape[1]: + decoder_input_ids = decoder_input_ids[:, -cache_position.shape[0] :] + elif ( + decoder_input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + decoder_input_ids = 
decoder_input_ids[:, cache_position] else: # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + decoder_input_ids = decoder_input_ids[:, -1:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed From b5088140773013d0b1ded55bdf4abc54f4b48624 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 11 Jun 2025 10:53:55 +0200 Subject: [PATCH 22/58] fix reformer --- .../models/reformer/modeling_reformer.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 3abad3258ffc..27faaf84295a 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -143,7 +143,7 @@ def update( self.buckets_cache[layer_idx] = torch.cat([self.buckets_cache[layer_idx], buckets], dim=-1) else: # `ReformerLocalAttn` passes `None` to buckets as the module uses no buckets - self.buckets_cache.append(torch.tensor([])) + self.buckets_cache.append(torch.tensor([], device=self.states_cache[layer_idx].device)) return self.buckets_cache[layer_idx], self.states_cache[layer_idx] @@ -155,7 +155,9 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: backward compatibility.""" legacy_cache = () for layer_idx in range(len(self)): - legacy_cache += ((self.buckets_cache[layer_idx], self.states_cache[layer_idx]),) + buckets, states = self.buckets_cache[layer_idx], self.states_cache[layer_idx] + buckets = buckets if buckets.numel() != 0 else None + legacy_cache += ((buckets, states),) return legacy_cache @classmethod @@ -495,7 +497,7 @@ def forward( past_buckets = past_buckets_states.buckets_cache[self.layer_idx] past_states = past_buckets_states.states_cache[self.layer_idx] - if past_buckets is not None: + if past_buckets.numel() != 0: key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( query_vectors=query_vectors, attention_mask=attention_mask, @@ -539,7 +541,7 @@ def forward( value_vectors = self.value(hidden_states) # if query key is not already split - if not use_cache or not exists_cache: + if not exists_cache or past_buckets.numel() == 0: query_key_vectors = self._split_hidden_size_dim( query_key_vectors, self.num_attention_heads, self.attention_head_size ) @@ -612,7 +614,7 @@ def forward( "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" " `config.num_chunks_before` are set to 0." 
) - elif exists_cache and past_buckets is not None: + elif exists_cache and past_buckets.numel() != 0: # use max sequence length sorted_bucket_idx_per_hash = sorted_bucket_idx else: @@ -640,7 +642,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, do_standard_self_attention=do_standard_self_attention, - use_cache=use_cache, + use_cache=exists_cache, ) # free memory @@ -651,7 +653,7 @@ def forward( # sort clusters back to correct ordering out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) - if not do_standard_self_attention or (exists_cache and past_buckets is not None): + if not do_standard_self_attention or (exists_cache and past_buckets.numel() != 0): # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( @@ -1427,7 +1429,6 @@ def forward( # add buckets if necessary if hasattr(self_attention_outputs, "buckets"): buckets = self_attention_outputs.buckets - buckets = buckets[:, :, :, :orig_sequence_length] if orig_sequence_length > 1 else buckets else: buckets = None @@ -1435,7 +1436,16 @@ def forward( if use_cache and past_buckets_states is not None: # padded input should not be cached during prefill states = ( - hidden_states[:, :orig_sequence_length] if len(past_buckets_states) <= self.layer_id else hidden_states + hidden_states[:, :orig_sequence_length] + if len(past_buckets_states.states_cache) <= self.layer_id + else hidden_states + ) + buckets = ( + buckets[:, :orig_sequence_length] + if len(past_buckets_states.buckets_cache) <= self.layer_id + and buckets is not None + and orig_sequence_length > 1 + else buckets ) buckets, hidden_states = past_buckets_states.update( buckets, states[:, :orig_sequence_length], self.layer_id @@ -2222,7 +2232,7 @@ def _pad_to_mult_of_chunk_length( # Extend `inputs_embeds` with padding to match least common multiple chunk_length if inputs_embeds is not None: - padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) + padded_inputs_embeds = self.get_input_embeddings()(padded_input_ids) inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) input_shape = inputs_embeds.size() return input_ids, inputs_embeds, attention_mask, position_ids, input_shape From b7deae60a774c3cb88358ead9f2ac8a3f9f92a0a Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 11 Jun 2025 15:57:38 +0200 Subject: [PATCH 23/58] fix reformer, again... 
--- src/transformers/models/reformer/modeling_reformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 27faaf84295a..e3082f1c063e 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -1441,10 +1441,12 @@ def forward( else hidden_states ) buckets = ( - buckets[:, :orig_sequence_length] - if len(past_buckets_states.buckets_cache) <= self.layer_id - and buckets is not None - and orig_sequence_length > 1 + buckets[:, :, :, :orig_sequence_length] + if ( + len(past_buckets_states.buckets_cache) <= self.layer_id + and buckets is not None + and orig_sequence_length > 1 + ) else buckets ) buckets, hidden_states = past_buckets_states.update( From ae88ecc83cfea5a9c32a048c447f77853d887a83 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 11 Jun 2025 16:45:08 +0200 Subject: [PATCH 24/58] delete `_supports_cache_class` --- src/transformers/generation/utils.py | 7 +++---- src/transformers/models/aria/modeling_aria.py | 4 ++-- src/transformers/models/aria/modular_aria.py | 2 +- src/transformers/models/aya_vision/modeling_aya_vision.py | 2 +- src/transformers/models/bamba/modeling_bamba.py | 2 +- src/transformers/models/bamba/modular_bamba.py | 2 +- src/transformers/models/bart/modeling_bart.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 2 +- src/transformers/models/biogpt/modeling_biogpt.py | 2 +- src/transformers/models/biogpt/modular_biogpt.py | 2 +- src/transformers/models/bitnet/modeling_bitnet.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 2 +- .../models/blenderbot_small/modeling_blenderbot_small.py | 2 +- src/transformers/models/blip_2/modeling_blip_2.py | 2 +- src/transformers/models/bloom/modeling_bloom.py | 2 +- src/transformers/models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/cohere2/modeling_cohere2.py | 2 +- src/transformers/models/colqwen2/modeling_colqwen2.py | 1 - src/transformers/models/colqwen2/modular_colqwen2.py | 1 - src/transformers/models/csm/modeling_csm.py | 2 +- src/transformers/models/csm/modular_csm.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- .../decision_transformer/modeling_decision_transformer.py | 2 +- .../models/deepseek_v3/modeling_deepseek_v3.py | 2 +- src/transformers/models/diffllama/modeling_diffllama.py | 2 +- src/transformers/models/emu3/modeling_emu3.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/falcon_h1/modeling_falcon_h1.py | 2 +- src/transformers/models/falcon_h1/modular_falcon_h1.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma3/modeling_gemma3.py | 2 +- src/transformers/models/git/modeling_git.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/glm4/modeling_glm4.py | 2 +- src/transformers/models/got_ocr2/modeling_got_ocr2.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- 
src/transformers/models/granite/modeling_granite.py | 2 +- .../models/granite_speech/modeling_granite_speech.py | 2 +- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 2 +- .../models/granitemoeshared/modeling_granitemoeshared.py | 2 +- src/transformers/models/helium/modeling_helium.py | 2 +- src/transformers/models/idefics/modeling_idefics.py | 2 +- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- src/transformers/models/idefics3/modeling_idefics3.py | 2 +- .../models/instructblip/modeling_instructblip.py | 4 ++-- .../models/instructblipvideo/modeling_instructblipvideo.py | 4 ++-- src/transformers/models/internvl/modeling_internvl.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/janus/modeling_janus.py | 2 +- src/transformers/models/janus/modular_janus.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 1 - src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/llama4/modeling_llama4.py | 2 +- src/transformers/models/llava/modeling_llava.py | 2 +- src/transformers/models/llava_next/modeling_llava_next.py | 2 +- .../models/llava_next_video/modeling_llava_next_video.py | 2 +- .../models/llava_onevision/modeling_llava_onevision.py | 2 +- src/transformers/models/longt5/modeling_longt5.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 2 +- src/transformers/models/marian/modeling_marian.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- src/transformers/models/minimax/modeling_minimax.py | 2 +- src/transformers/models/minimax/modular_minimax.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mistral3/modeling_mistral3.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/mllama/modeling_mllama.py | 2 +- src/transformers/models/moonshine/modeling_moonshine.py | 2 +- src/transformers/models/moonshine/modular_moonshine.py | 2 +- src/transformers/models/moshi/modeling_moshi.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- src/transformers/models/paligemma/modeling_paligemma.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 2 +- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../models/phi4_multimodal/modeling_phi4_multimodal.py | 2 +- src/transformers/models/phimoe/modeling_phimoe.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 +- src/transformers/models/pop2piano/modeling_pop2piano.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 - src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 2 +- src/transformers/models/qwen3/modeling_qwen3.py | 2 +- src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 2 +- .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 +- 
src/transformers/models/smolvlm/modeling_smolvlm.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- .../switch_transformers/modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- .../models/video_llava/modeling_video_llava.py | 2 +- src/transformers/models/vipllava/modeling_vipllava.py | 2 +- src/transformers/models/whisper/modeling_whisper.py | 2 +- src/transformers/models/zamba/modeling_zamba.py | 2 +- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- src/transformers/models/zamba2/modular_zamba2.py | 2 +- src/transformers/utils/args_doc.py | 3 --- 117 files changed, 117 insertions(+), 125 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4f0f3bda3539..f69223aac3af 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1944,10 +1944,9 @@ def _get_cache( def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in - order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). + This adds an exception for models such as `Jamba`, which use their own `HybridMambaAttentionDynamicCache` + and do not need to initialize the cache in advance in order to save memory (because no back-and-forth + `to_legacy_cache` and `from_legacy_cache` conversion is performed for `HybridMambaAttentionDynamicCache`).
""" return all( special_model_name not in self.__class__.__name__.lower() diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index ebddfcfe7115..4fa0df8d5843 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -641,7 +641,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): @@ -670,7 +670,7 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index c80351cd9a81..503b7bac3d69 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1277,7 +1277,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 1d723f1b5aee..f6c3048c491f 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -92,7 +92,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = False diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 1b8e12d1c3b2..53a9a01e4492 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1043,7 +1043,7 @@ class BambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 9db52ebfbc5d..34956388f38e 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -816,7 +816,7 @@ class BambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index c5269516eaf6..f3375ed57c8e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -495,7 +495,7 @@ 
class BartPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index df30b5f7fa68..4537c94ad78d 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1559,7 +1559,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_param_buffer_assignment = False - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 2261684e3e2f..04bd54a9481d 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -349,7 +349,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index a498ed15ee0e..cd9f890cc55c 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -176,7 +176,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 661a3c9bb60e..d06b5e24596b 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -321,7 +321,7 @@ class BitNetPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 97df16a0acfd..9e2531cacfcd 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -465,7 +465,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 7115b9d3eb10..e612b5c10b65 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -453,7 +453,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + 
_supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index a0ecbf243c27..11cb72eccafa 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1874,7 +1874,7 @@ def forward( class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens", "qformer"] diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f93ee1e60d6d..2e88c1ebc3ac 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -446,7 +446,7 @@ class BloomPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = True diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 5644f68baa92..a142dba796b9 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -827,7 +827,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False _supports_flex_attn = True diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 8fca177b7584..719b93945d76 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -298,7 +298,7 @@ class CodeGenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CodeGenBlock"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0700eb8e9f60..3c9a6a20fbeb 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -358,7 +358,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 5690864cfc55..7c3391b66757 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -337,7 +337,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py 
b/src/transformers/models/colqwen2/modeling_colqwen2.py index c4fdc63567be..2d68ae208b0f 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -42,7 +42,6 @@ class ColQwen2PreTrainedModel(PreTrainedModel): _no_split_modules = [] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): std = ( diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 43c4cc5308dd..ea774e67ab59 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -227,7 +227,6 @@ def __call__( class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True @dataclass diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index c0c4f5927a57..aff8afc14df7 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -120,7 +120,7 @@ class CsmPreTrainedModel(PreTrainedModel): _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 4322a2a07f8c..1e73dc1c364f 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -120,7 +120,7 @@ class CsmPreTrainedModel(PreTrainedModel): _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 0a530e87ae1b..a35d4cf70af4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -808,7 +808,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index c577d6f17c65..3d1671d0b5f3 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -451,7 +451,7 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 5804eeee4b17..c69d0a065bf6 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -500,7 +500,7 @@ class 
DeepseekV3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 68aa54180caa..fef158efff93 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -554,7 +554,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = False - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = False diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 995e9cac7d60..47fa6411e923 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1139,7 +1139,7 @@ class Emu3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False _supports_flex_attn = True diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ee013faf9b4d..be4c3187582c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -657,7 +657,7 @@ class FalconPreTrainedModel(PreTrainedModel): _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 0a4d8f432777..b8d70deafbd8 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1161,7 +1161,7 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache + # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index bd0ecb1804d1..fb8fefbd5ff9 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -936,7 +936,7 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache + # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 95b50039eb9b..c83b42c3a6cc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -321,7 +321,7 @@ class GemmaPreTrainedModel(PreTrainedModel): 
_supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 7bb865bc5dcd..bb389f6626a5 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -342,7 +342,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 08740173009d..b0099bbb8de5 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -445,7 +445,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 2282b928075f..519918e24d8d 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -497,7 +497,7 @@ class GitPreTrainedModel(PreTrainedModel): config_class = GitConfig base_model_prefix = "git" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_quantized_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 235f8258c10c..996150ee881a 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -338,7 +338,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 9a22b2561751..dab63c4cb0f9 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -346,7 +346,7 @@ class Glm4PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index fc3ab807df82..644ecc2e8fcb 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -588,7 +588,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index d4f6732db6a7..c0c76146b8d3 
100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -561,7 +561,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_attention_backend = True - _supports_cache_class = True + _supports_static_cache = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index f37b503e8e54..fae438f9c16b 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -485,7 +485,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 16de0f23db91..2ff9b7603383 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -294,7 +294,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 9d15ef357a67..dbfd947af4a3 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -47,7 +47,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): base_model_prefix = "gpt_neox_japanese" _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 5434c75ade07..7454ec3a497d 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -483,7 +483,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 11f2873f3dfc..77fb9811fb1d 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -308,7 +308,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index bfce41e6fac0..1a5fb32fd4e6 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ 
b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -287,7 +287,7 @@ def forward(self, hidden_states: torch.Tensor): @auto_docstring class GraniteSpeechPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechConfig - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index e6833677a776..3953d3064e88 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -597,7 +597,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c7fac08e57c1..8bbdbdba3760 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1159,7 +1159,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _is_stateful = True diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 272618d1fe3c..f47ab9107413 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -511,7 +511,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index b9cb3bafc13d..5c5a632c9066 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -323,7 +323,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 6b6fc0df056b..82b39576248b 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -916,7 +916,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] _supports_sdpa = True - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_static_cache = False # IDEFICS cannot compile 
due to dynamic control flow when checking inputs _supports_attention_backend = True diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 1132276ef1ac..6d54f044d3f8 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -487,7 +487,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 41043b3e7455..f754981b32d3 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -505,7 +505,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 06e92a170178..bf654690cc09 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -335,7 +335,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) @@ -1375,7 +1375,7 @@ def forward( class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): config_class = InstructBlipConfig main_input_name = "pixel_values" - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 6f4d599cada0..2d4090754824 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -852,7 +852,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) @@ -1381,7 +1381,7 @@ def forward( class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" - _supports_cache_class = True + _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. 
T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 4c747a2394e5..1f12c267a072 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -534,7 +534,7 @@ class InternVLPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index d60190161ef6..8eb6799ebae4 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1065,7 +1065,7 @@ class JambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index a526ce5d7af1..7a515c0efdf7 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -63,7 +63,7 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 599348026d16..737cc7187f2b 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -385,7 +385,7 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 788b2066b5dc..c33ddff23144 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -823,7 +823,6 @@ class JetMoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): """Initialize the weights.""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4502cee6e573..41245cabf49b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -327,7 +327,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index fe77ea4a58c8..50b52f232590 100644 --- 
a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -435,7 +435,7 @@ class Llama4PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = False _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 4321dea59d94..70fcbf7a8127 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -136,7 +136,7 @@ class LlavaPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index e5b597e819ed..3f56da38ddac 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -247,7 +247,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 368fd09ef32f..81c791d5c22c 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -190,7 +190,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 5a1157edebf3..96fe00778783 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -303,7 +303,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 1e64e1b85ada..92efcecff1c4 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1251,7 +1251,7 @@ class LongT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] - _supports_cache_class = True + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn @property diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py 
b/src/transformers/models/m2m_100/modeling_m2m_100.py index 7e8dcd4f47ad..582234615c9d 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -531,7 +531,7 @@ class M2M100PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model _supports_static_cache = False diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 9d613f6ccccc..c256d4b548e2 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -469,7 +469,7 @@ class MarianPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 25a73249853a..8d310cb69e9a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -499,7 +499,7 @@ class MBartPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 4d7b92979ab1..0d1283f12684 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1437,7 +1437,7 @@ class MimiPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 18d2e4df7d9c..fb450cfd6a1e 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -592,7 +592,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True # Note: only supports MiniMaxCache + # Note: only supports MiniMaxCache _supports_quantized_cache = False _supports_static_cache = False _supports_attention_backend = True diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 0028dcbfb6c0..186c9961e1b2 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -472,7 +472,7 @@ def forward( class MiniMaxPreTrainedModel(MixtralPreTrainedModel): - _supports_cache_class = True # Note: only supports MiniMaxCache + # Note: only supports MiniMaxCache _supports_static_cache = False _supports_quantized_cache = False diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90881cbcd2b9..cfb534256767 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -265,7 +265,7 @@ class 
MistralPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 082020b3afd0..70ecc3b38fdf 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -201,7 +201,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9147538f73cc..9c66c2f6fc6b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -420,7 +420,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 4c73295cab65..78a925fe5f85 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -855,7 +855,7 @@ class MllamaPreTrainedModel(PreTrainedModel): "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer", ] - _supports_cache_class = True + _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn_2 = True diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index a3aebaed9a50..abbd1ba5154c 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -484,7 +484,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index f99de20eb02c..820ab5f8d2ba 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -508,7 +508,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index fdaa5246064b..e108e7147e95 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -828,7 +828,7 @@ class MoshiPreTrainedModel(PreTrainedModel): _no_split_modules = 
["MoshiDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + main_input_name = "input_ids" def _init_weights(self, module): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 36c4b9f3ba7f..a4764ef0e9ea 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -758,7 +758,7 @@ class MT5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True - _supports_cache_class = True + _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index ea0f0de04569..f6f7fb7f6088 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -577,7 +577,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 36999733b3a9..4c413f04c659 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -304,7 +304,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 661a9341d67a..5a4b5aaccfb0 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -308,7 +308,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 88f884dc2e43..4b274ea95c71 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -706,7 +706,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 7413c92630c7..0fcd9bff7c90 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -318,7 +318,7 @@ class OPTPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/paligemma/modeling_paligemma.py 
b/src/transformers/models/paligemma/modeling_paligemma.py index addba2b30fef..0192ba74502f 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -126,7 +126,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_flash_attn_2 = True diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 9ade40688847..e084d6e2377b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -464,7 +464,7 @@ class PegasusPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 549c27cbab16..d20a43d47618 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -765,7 +765,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): # Flaky logits _supports_sdpa = False _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a5376d651cee..ac8bcb766cc9 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -389,7 +389,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_sdpa = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3f2deffd9e0d..c6fdddfa658a 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -298,7 +298,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 08f93a468b5f..5c7ec6fd4c01 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -319,7 +319,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index d484f9255a8d..d3c143c2e579 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ 
b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1625,7 +1625,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index e81d38e2d88d..e23dadb7fb5d 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -884,7 +884,7 @@ class PhimoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 413cc77e4461..90da4f44797c 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -361,7 +361,7 @@ def forward( @auto_docstring class Pix2StructPreTrainedModel(PreTrainedModel): config_class = Pix2StructConfig - _supports_cache_class = True + _supports_static_cache = False @property diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 1d309e48d802..f644c683880a 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -577,7 +577,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = False supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 03df9df94f6c..2a1095e01d27 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -269,7 +269,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 8164ad0f4bac..b9afe03bfe71 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -98,7 +98,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = False _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 1b4cfea5b665..5ec53599d96d 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -350,7 +350,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): 
_skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 7f81c331ccc9..0e766f569dd8 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -741,7 +741,6 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a995d064b529..652cc6bcc6e7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -732,7 +732,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index ef1cd22e0c04..93723de174f3 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -295,7 +295,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 77b2362fe192..f25cc6609ae8 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -425,7 +425,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index b573421c581f..97818aff0aeb 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -514,7 +514,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False _supports_sdpa = False # we can't compare with eager for now - _supports_cache_class = True + _supports_quantized_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index bbf05404ac12..b92754444b7e 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -51,7 +51,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = 
True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index be9af304d51d..57442a7a44dc 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -619,7 +619,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_cache_class = True + _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6e102d801426..15b79b795c17 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -302,7 +302,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d6dba63ee428..e2e167c18ba3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -767,7 +767,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["SwitchTransformersBlock"] diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b6dcfa9548b8..eaeb737966c0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -774,7 +774,7 @@ class T5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True - _supports_cache_class = True + _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 1e2272019685..8b4db278b8d5 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -250,7 +250,7 @@ class UdopPreTrainedModel(PreTrainedModel): config_class = UdopConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index f1c4c22a77d4..9ba0c7c6e6e0 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -506,7 +506,7 @@ class UMT5PreTrainedModel(PreTrainedModel): config_class = UMT5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] diff --git 
a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index ed7a19ca6645..37130a2466f7 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -150,7 +150,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index c4a20aef9148..6b303c47ed77 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -137,7 +137,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 97c679dbb4a6..1b7eba6bd6ab 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -561,7 +561,7 @@ class WhisperPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 7733decd0198..79fb3a5383e7 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -794,7 +794,7 @@ class ZambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = False - _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache + # Note: only supports ZambaHybridDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index b2aed168239f..fb418a67ceb2 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1196,7 +1196,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_flex_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache + # Note: only supports Zamba2HybridDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c4e14dd14824..6ef608f7f3db 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -911,7 +911,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_flex_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache + # Note: only supports Zamba2HybridDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/utils/args_doc.py 
b/src/transformers/utils/args_doc.py index 7048a223bb6e..018b881889f0 100644 --- a/src/transformers/utils/args_doc.py +++ b/src/transformers/utils/args_doc.py @@ -673,9 +673,6 @@ class ClassAttrs: _supports_flex_attn = r""" Whether the model's attention implementation supports FlexAttention. """ - _supports_cache_class = r""" - Whether the model supports a `Cache` instance as `past_key_values`. - """ _supports_quantized_cache = r""" Whether the model supports a `QuantoQuantizedCache` instance as `past_key_values`. """ From f1ec0ba988d0185070a0bc4ffb1fe9cb7fcf489a Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 11 Jun 2025 17:26:38 +0200 Subject: [PATCH 25/58] delete `supports_quantized_cache` --- src/transformers/generation/utils.py | 16 +++++++++++----- src/transformers/modeling_utils.py | 4 ---- src/transformers/models/aria/modeling_aria.py | 1 - .../models/aya_vision/modeling_aya_vision.py | 1 - .../models/aya_vision/modular_aya_vision.py | 1 - .../models/bitnet/modeling_bitnet.py | 1 - .../models/blip_2/modeling_blip_2.py | 1 - src/transformers/models/bloom/modeling_bloom.py | 1 - .../models/chameleon/modeling_chameleon.py | 1 - .../models/codegen/modeling_codegen.py | 1 - .../models/cohere/modeling_cohere.py | 1 - .../models/cohere2/modeling_cohere2.py | 1 - src/transformers/models/csm/modeling_csm.py | 1 - src/transformers/models/csm/modular_csm.py | 1 - src/transformers/models/dbrx/modeling_dbrx.py | 1 - .../models/deepseek_v3/modeling_deepseek_v3.py | 1 - .../models/diffllama/modeling_diffllama.py | 1 - src/transformers/models/emu3/modeling_emu3.py | 1 - .../models/falcon/modeling_falcon.py | 1 - .../models/falcon_h1/modeling_falcon_h1.py | 1 - src/transformers/models/gemma/modeling_gemma.py | 1 - .../models/gemma2/modeling_gemma2.py | 1 - .../models/gemma3/modeling_gemma3.py | 1 - src/transformers/models/git/modeling_git.py | 2 -- src/transformers/models/glm/modeling_glm.py | 1 - src/transformers/models/glm4/modeling_glm4.py | 1 - .../models/got_ocr2/modeling_got_ocr2.py | 2 +- .../models/gpt_neo/modeling_gpt_neo.py | 1 - .../models/gpt_neox/modeling_gpt_neox.py | 1 - .../modeling_gpt_neox_japanese.py | 1 - src/transformers/models/gptj/modeling_gptj.py | 1 - .../models/granite/modeling_granite.py | 1 - .../models/granitemoe/modeling_granitemoe.py | 1 - .../modeling_granitemoehybrid.py | 10 ---------- .../granitemoehybrid/modular_granitemoehybrid.py | 9 --------- .../modeling_granitemoeshared.py | 1 - .../models/helium/modeling_helium.py | 1 - .../models/instructblip/modeling_instructblip.py | 2 -- .../modeling_instructblipvideo.py | 2 -- .../models/internvl/modeling_internvl.py | 2 +- src/transformers/models/janus/modeling_janus.py | 1 - src/transformers/models/janus/modular_janus.py | 1 - src/transformers/models/llama/modeling_llama.py | 1 - .../models/llama4/modeling_llama4.py | 1 - src/transformers/models/llava/modeling_llava.py | 2 +- .../models/llava_next/modeling_llava_next.py | 2 +- .../modeling_llava_next_video.py | 2 +- .../llava_onevision/modeling_llava_onevision.py | 2 +- .../models/minimax/modeling_minimax.py | 1 - .../models/minimax/modular_minimax.py | 1 - .../models/mistral/modeling_mistral.py | 1 - .../models/mistral3/modeling_mistral3.py | 2 +- .../models/mixtral/modeling_mixtral.py | 1 - .../models/mllama/modeling_mllama.py | 4 +--- src/transformers/models/mt5/modeling_mt5.py | 1 - .../models/nemotron/modeling_nemotron.py | 1 - src/transformers/models/olmo/modeling_olmo.py | 1 - src/transformers/models/olmo2/modeling_olmo2.py | 1 - 
src/transformers/models/olmoe/modeling_olmoe.py | 1 - src/transformers/models/opt/modeling_opt.py | 1 - .../models/paligemma/modeling_paligemma.py | 1 - .../models/persimmon/modeling_persimmon.py | 1 - src/transformers/models/phi/modeling_phi.py | 1 - src/transformers/models/phi3/modeling_phi3.py | 1 - .../phi4_multimodal/modeling_phi4_multimodal.py | 1 - .../models/phimoe/modeling_phimoe.py | 1 - src/transformers/models/qwen2/modeling_qwen2.py | 1 - src/transformers/models/qwen3/modeling_qwen3.py | 1 - .../models/qwen3_moe/modeling_qwen3_moe.py | 1 - .../recurrent_gemma/modeling_recurrent_gemma.py | 2 -- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 1 - src/transformers/models/t5/modeling_t5.py | 1 - .../models/video_llava/modeling_video_llava.py | 2 +- .../models/vipllava/modeling_vipllava.py | 2 +- tests/generation/test_utils.py | 9 ++++++--- tests/models/mllama/test_modeling_mllama.py | 7 +++++++ 77 files changed, 35 insertions(+), 108 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f69223aac3af..4cc6f5ae8d23 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1941,16 +1941,22 @@ def _get_cache( self._cache.reset() return self._cache - def _supports_default_dynamic_cache(self) -> bool: + @classmethod + def _supports_default_dynamic_cache(cls) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. This adds exception for some models like `Jamba` model which uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). """ - return all( - special_model_name not in self.__class__.__name__.lower() - for special_model_name in ["jamba", "zamba", "mamba", "bamba", "reformer", "minimax"] + # NOTE: remove xlnet/reformer when these models are deprecated; they use `mems` as the cache name + return not cls._is_stateful and all( + special_model_name not in cls.__name__.lower() + for special_model_name in [ + "reformer", + "minimax", + "xlnet", + ] ) def _prepare_cache_for_generation( @@ -2037,7 +2043,7 @@ def _prepare_cache_for_generation( model_kwargs=model_kwargs, ) elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: + if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache(): raise ValueError( "This model does not support the quantized cache. If you want your model to support quantized " "cache, please open an issue and tag @zucchini-nlp." diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 64d5737545c7..b0020d87d6e3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1968,12 +1968,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi _supports_flex_attn = False # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? - _supports_cache_class = False _supports_static_cache = False - # Has support for a `QuantoQuantizedCache` instance as `past_key_values` - _supports_quantized_cache = False - # A tensor parallel plan to be applied to the model when TP is enabled. For # top-level models, this attribute is currently defined in respective model # code.
For base models, this attribute comes from diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 4fa0df8d5843..b42dedb46683 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -671,7 +671,6 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index f6c3048c491f..d47a55448233 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -95,7 +95,6 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = False _supports_static_cache = False _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 533a7f444472..fa698fc241ec 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -89,7 +89,6 @@ def pixel_shuffle(self, image_features): # B, S, D class AyaVisionPreTrainedModel(LlavaPreTrainedModel): - _supports_quantized_cache = False _supports_static_cache = False def _init_weights(self, module): diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index d06b5e24596b..35d5f26920e0 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -322,7 +322,6 @@ class BitNetPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 11cb72eccafa..6d13e88792ed 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1876,7 +1876,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): main_input_name = "pixel_values" _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. 
T5) _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 2e88c1ebc3ac..0764a102c3af 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -448,7 +448,6 @@ class BloomPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_static_cache = True - _supports_quantized_cache = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index a142dba796b9..23b50cd8fd4f 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -826,7 +826,6 @@ class ChameleonPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 719b93945d76..5fac0fd1953e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -299,7 +299,6 @@ class CodeGenPreTrainedModel(PreTrainedModel): _no_split_modules = ["CodeGenBlock"] _skip_keys_device_placement = "past_key_values" - _supports_quantized_cache = True _supports_static_cache = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 3c9a6a20fbeb..295f2e7a5bea 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -359,7 +359,6 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 7c3391b66757..22eaee98fd07 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -338,7 +338,6 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index aff8afc14df7..ebb03bfc9052 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -121,7 +121,6 @@ class CsmPreTrainedModel(PreTrainedModel): # does not because of Mimi codec model # _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 1e73dc1c364f..836f531adeb3 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -121,7 +121,6 @@ class CsmPreTrainedModel(PreTrainedModel): # does not because of Mimi codec model # _supports_flex_attn = True - 
_supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a35d4cf70af4..e8fa397d3067 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -809,7 +809,6 @@ class DbrxPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index c69d0a065bf6..aafe0336c8a1 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -501,7 +501,6 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index fef158efff93..a2888a29a7f5 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -555,7 +555,6 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = False diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 47fa6411e923..366c82e0fd04 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1138,7 +1138,6 @@ class Emu3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index be4c3187582c..cada0fac99b9 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -658,7 +658,6 @@ class FalconPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index b8d70deafbd8..c97b13942273 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1161,7 +1161,6 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c83b42c3a6cc..33a13c49ef9a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ 
b/src/transformers/models/gemma/modeling_gemma.py @@ -322,7 +322,6 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index bb389f6626a5..d35eb78762c6 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -343,7 +343,6 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index b0099bbb8de5..edc14fb023bb 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -446,7 +446,6 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 519918e24d8d..27526394c783 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -498,8 +498,6 @@ class GitPreTrainedModel(PreTrainedModel): base_model_prefix = "git" supports_gradient_checkpointing = True - _supports_quantized_cache = True - def _init_weights(self, module): """Initialize the weights""" if isinstance(module, GitVisionEmbeddings): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 996150ee881a..765074ecf425 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -339,7 +339,6 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index dab63c4cb0f9..79059f4fcf26 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -347,7 +347,6 @@ class Glm4PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 644ecc2e8fcb..f66a7079fcbf 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -591,7 +591,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index fae438f9c16b..59f3bc731cb5 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -486,7 +486,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel): 
_skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 2ff9b7603383..4fac6f079915 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -295,7 +295,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True _keys_to_ignore_on_load_unexpected = [r"attention.bias", r"attention.masked_bias"] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index dbfd947af4a3..7bc9e9f1dd01 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -48,7 +48,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" - _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 7454ec3a497d..2654ab8f067f 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -484,7 +484,6 @@ class GPTJPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 77fb9811fb1d..f642e74e4829 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -309,7 +309,6 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 3953d3064e88..e1f27707ee2e 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -598,7 +598,6 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 8bbdbdba3760..92de8ea1bef6 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1160,7 +1160,6 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # 
MoE models don't work with torch.compile (`torch.where(condition)` not supported) _is_stateful = True @@ -1794,14 +1793,5 @@ def prepare_inputs_for_generation( ) return model_inputs - def _supports_default_dynamic_cache(self) -> bool: - """ - Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache` - and do not need to initialize the Cache in advance in order to save memory - (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return False - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index f9dc0ec3ea6a..e2fbe16bdb1e 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -381,14 +381,5 @@ def prepare_inputs_for_generation( ) return model_inputs - def _supports_default_dynamic_cache(self) -> bool: - """ - Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache` - and do not need to initialize the Cache in advance in order to save memory - (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return False - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index f47ab9107413..3cfbfc4169b4 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -512,7 +512,6 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 5c5a632c9066..7d67861be936 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -324,7 +324,6 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index bf654690cc09..9f4055c25fb9 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -337,7 +337,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipQFormerEmbeddings", @@ -1377,7 +1376,6 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati main_input_name = "pixel_values" _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. 
T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipConfig): diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 2d4090754824..91420deaa38b 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -854,7 +854,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipVideoQFormerEmbeddings", @@ -1383,7 +1382,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel main_input_name = "pixel_values" _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipVideoConfig): diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 1f12c267a072..0acfbf13b39b 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -537,7 +537,7 @@ class InternVLPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 7a515c0efdf7..7544527733f7 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -62,7 +62,6 @@ class JanusPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 737cc7187f2b..8bc1c16c83ed 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -384,7 +384,6 @@ class JanusPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 41245cabf49b..8a4514e3c23b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -328,7 +328,6 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 50b52f232590..ae531644155e 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -436,7 
+436,6 @@ class Llama4PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 70fcbf7a8127..3e6fc9312169 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -139,7 +139,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 3f56da38ddac..291a02671554 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -250,7 +250,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 81c791d5c22c..6d3f31886b43 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -193,7 +193,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 96fe00778783..27e01b3d4724 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -306,7 +306,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index fb450cfd6a1e..2b19adf706ef 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -593,7 +593,6 @@ class MiniMaxPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True # Note: only supports MiniMaxCache - _supports_quantized_cache = False _supports_static_cache = False _supports_attention_backend = True diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 186c9961e1b2..c0d62c7513c9 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -474,7 +474,6 @@ def forward( class MiniMaxPreTrainedModel(MixtralPreTrainedModel): # Note: only supports MiniMaxCache _supports_static_cache = False - _supports_quantized_cache = False class MiniMaxModel(MixtralModel): diff --git a/src/transformers/models/mistral/modeling_mistral.py 
b/src/transformers/models/mistral/modeling_mistral.py index cfb534256767..93e406d83714 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -266,7 +266,6 @@ class MistralPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 70ecc3b38fdf..69029e30fa5f 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -204,7 +204,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c66c2f6fc6b..0fdb2b731d92 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -421,7 +421,6 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 78a925fe5f85..939aa226cc58 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -859,7 +859,7 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn_2 = True - _supports_quantized_cache = True + _supports_flex_attn = True _supports_attention_backend = True @@ -1616,7 +1616,6 @@ def forward( ) class MllamaModel(MllamaPreTrainedModel): _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting def __init__(self, config: MllamaConfig): super().__init__(config) @@ -1770,7 +1769,6 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: MllamaConfig): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index a4764ef0e9ea..4c83a55f88e1 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -756,7 +756,6 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True _no_split_modules = ["MT5Block"] diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index f6f7fb7f6088..c506a6ed5448 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py 
+++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -578,7 +578,6 @@ class NemotronPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 4c413f04c659..e7e816caf95d 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -305,7 +305,6 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 5a4b5aaccfb0..658c77212144 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -309,7 +309,6 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 4b274ea95c71..cca11051bbf0 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -707,7 +707,6 @@ class OlmoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 0fcd9bff7c90..d695229611bc 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -319,7 +319,6 @@ class OPTPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 0192ba74502f..ec1863c78879 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -127,7 +127,6 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_quantized_cache = True _supports_static_cache = True _supports_flash_attn_2 = True _supports_sdpa = True diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ac8bcb766cc9..79147015e357 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -390,7 +390,6 @@ class PersimmonPreTrainedModel(PreTrainedModel): _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_quantized_cache = True _supports_static_cache = True _supports_sdpa = True _supports_flash_attn_2 = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c6fdddfa658a..e13441982de4 100644 --- 
a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -299,7 +299,6 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 5c7ec6fd4c01..6e68b16cb119 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -320,7 +320,6 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True _version = "0.0.5" diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index d3c143c2e579..4bb82f16bb3b 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1626,7 +1626,6 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True _version = "0.0.5" diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index e23dadb7fb5d..f06e886266ac 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -885,7 +885,6 @@ class PhimoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 2a1095e01d27..e7742432ecd8 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -270,7 +270,6 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 93723de174f3..7d89200c15c3 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -296,7 +296,6 @@ class Qwen3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f25cc6609ae8..fc28f68bb716 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -426,7 +426,6 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True diff --git 
a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 97818aff0aeb..49da2bb980c4 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -515,8 +515,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = False _supports_sdpa = False # we can't compare with eager for now - _supports_quantized_cache = True - def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) if isinstance(module, nn.Conv1d): diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 57442a7a44dc..686fa7186111 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -621,7 +621,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 15b79b795c17..6472ae149c61 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -303,7 +303,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index eaeb737966c0..a094a850d7d7 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -772,7 +772,6 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True _no_split_modules = ["T5Block"] diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 37130a2466f7..4c6db9ada70f 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -153,7 +153,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 6b303c47ed77..c1bc2f8807fd 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -140,7 +140,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e79efa05f1cf..3582a06b9021 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2038,12 +2038,15 @@ def test_generate_with_static_cache(self): 
@pytest.mark.generate def test_generate_with_quant_cache(self): for model_class in self.all_generative_model_classes: - if not model_class._supports_quantized_cache: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + + if ( + config.get_text_config(decoder=True).is_encoder_decoder + or not model_class._supports_default_dynamic_cache() + ): self.skipTest(reason="This model does not support the quantized cache format") - config, inputs_dict = self.prepare_config_and_inputs_for_generate() config.is_decoder = True - model = model_class(config).to(torch_device).eval() generation_kwargs = { "max_new_tokens": 5, diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 41ae39681698..6b3113ded24c 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -34,6 +34,7 @@ Expectations, cleanup, require_bitsandbytes, + require_optimum_quanto, require_read_token, require_torch, require_torch_accelerator, @@ -344,6 +345,12 @@ def _check_attentions_for_generate( self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes) + @require_optimum_quanto + @pytest.mark.generate + @unittest.skip("Mllama actually uses an encoder-decoder cache and thus can't support quant cache") + def test_generate_with_quant_cache(self): + pass + @unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ") def test_sdpa_can_compile_dynamic(self): pass From 8f5d8a03a346e2dee22060b813f145a63e1dd065 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 11 Jun 2025 18:15:59 +0200 Subject: [PATCH 26/58] fix failing tests --- .../models/reformer/modeling_reformer.py | 2 +- tests/test_modeling_common.py | 21 +++---------------- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index e3082f1c063e..fe4187a921ce 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2379,7 +2379,7 @@ def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] for layer_past in past_key_values: # buckets - if layer_past[0].numel() != 0: + if layer_past[0] is not None: reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)) else: reord_buckets = None diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc62d56894e1..78fd7d013f59 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -524,7 +524,7 @@ def test_can_init_all_missing_weights(self): # For now, skip everything older than 2025 and "important models" (too much models to patch otherwise) # Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them # TODO: relax this as we patch more and more models - if addition_year < 2025 and not model_class._supports_cache_class: + if addition_year < 2025: self.skipTest(reason=f"{model_class} is not a priorited model for now.") # Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps @@ -1327,18 +1327,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa head_dim = model.config.hidden_size // model.config.num_attention_heads cache_shape = (batch_size, num_heads, 0, head_dim) - empty_pkv = tuple( - ( - torch.rand(cache_shape, dtype=torch.float,
device=torch_device), - torch.rand(cache_shape, dtype=torch.float, device=torch_device), - ) - for i in range(model.config.num_hidden_layers) - ) - empty_pkv = ( - DynamicCache.from_legacy_cache(empty_pkv) - if model_class._supports_cache_class - else empty_pkv - ) + empty_pkv = DynamicCache() cache_length = 9 cache_shape = (batch_size, num_heads, cache_length, head_dim) @@ -1349,11 +1338,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ) for i in range(model.config.num_hidden_layers) ) - non_empty_pkv = ( - DynamicCache.from_legacy_cache(non_empty_pkv) - if model_class._supports_cache_class - else non_empty_pkv - ) + non_empty_pkv = DynamicCache.from_legacy_cache(non_empty_pkv) inps = copy.deepcopy(inputs_to_test[0]) From 08ad1b0773dadcd14a0240ae9d32057ced91ea1e Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 09:08:19 +0200 Subject: [PATCH 27/58] fix copies --- src/transformers/models/aria/modeling_aria.py | 1 - src/transformers/models/falcon_h1/modular_falcon_h1.py | 1 - src/transformers/models/mixtral/modeling_mixtral.py | 1 - src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 1 - src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 1 - 5 files changed, 5 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index b42dedb46683..889eb86a7918 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -670,7 +670,6 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index fb8fefbd5ff9..e8aa0bad1b5a 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -936,7 +936,6 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 0fdb2b731d92..32a848aad123 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -420,7 +420,6 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index b9afe03bfe71..888552fe9f3d 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -98,7 +98,6 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_static_cache = False _supports_attention_backend = True diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py 
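The torch.fx tracing test earlier in this commit now builds `DynamicCache` objects directly; a small sketch of the legacy-tuple round trip it replaces (shapes are arbitrary, chosen only for illustration):

import torch
from transformers import DynamicCache

# legacy format: one (key, value) pair per layer, each of shape [batch, heads, seq_len, head_dim]
legacy = tuple(
    (torch.rand(2, 4, 9, 8), torch.rand(2, 4, 9, 8))
    for _ in range(3)
)
cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 9
assert len(cache.to_legacy_cache()) == 3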
b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index fc28f68bb716..99ff7b2c1e14 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -425,7 +425,6 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True From dfdf50bd02e54324b9c0a08fa948e766464fc00b Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 09:47:37 +0200 Subject: [PATCH 28/58] some minor clean up --- src/transformers/generation/utils.py | 8 ++--- src/transformers/modeling_utils.py | 2 +- tests/generation/test_utils.py | 48 +++++++--------------------- 3 files changed, 17 insertions(+), 41 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4cc6f5ae8d23..e10f737fc723 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1949,7 +1949,7 @@ def _supports_default_dynamic_cache(cls) -> bool: and do not need to initialize the Cache in advance in order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). """ - # NOTE: remove xlnet/reformer when the models is deprecated, it uses `mems` as cache name + # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name return not cls._is_stateful and all( special_model_name not in cls.__name__.lower() for special_model_name in [ @@ -4132,7 +4132,7 @@ def _beam_search( # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`) # pluck the cache from the beam indices that will be used in the next iteration - # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]) if hasattr(self, "_reorder_cache"): @@ -4435,7 +4435,7 @@ def _group_beam_search( # (that way the memory peak does not include outputs.logits) del outputs - # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: if hasattr(self, "_reorder_cache"): model_kwargs["past_key_values"] = self._reorder_cache( @@ -4676,7 +4676,7 @@ def _constrained_beam_search( # (that way the memory peak does not include outputs.logits) del outputs - # NOTE: we need to check if `self._reorder_cache` for special models like RAG, RecurrentGemma etc. + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. 
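To make the `_supports_default_dynamic_cache` gate touched above concrete, a hedged sketch of how a caller can check it before handing `generate` an explicit cache (any decoder-only checkpoint works; `gpt2` is just an example):

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Stateful or non-standard-cache models (xlnet, reformer, ...) return False here
# and must not be given a plain DynamicCache.
if model._supports_default_dynamic_cache():
    out = model.generate(**inputs, max_new_tokens=5, past_key_values=DynamicCache())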
if model_kwargs.get("past_key_values", None) is not None: if hasattr(self, "_reorder_cache"): model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b0020d87d6e3..8b8eee33b377 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1967,7 +1967,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Flex Attention support _supports_flex_attn = False - # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? + # Has support for `torch.compile(fullgraph=True)` _supports_static_cache = False # A tensor parallel plan to be applied to the model when TP is enabled. For diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3582a06b9021..6bab4126bdb4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1012,7 +1012,7 @@ def test_contrastive_generate(self): self.skipTest(reason="Stateful models don't support contrastive search generation") # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") config, inputs_dict = self.prepare_config_and_inputs_for_generate() @@ -1041,7 +1041,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self): self.skipTest(reason="Stateful models don't support contrastive search generation") # won't fix: FSMT and Reformer have a different cache variable type (and format).
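The `_supports_static_cache` comment rewritten above ties the flag to `torch.compile(fullgraph=True)`; a minimal sketch of what the flag enables at generation time (the model choice is illustrative, not taken from the patch):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# A pre-allocated StaticCache keeps tensor shapes fixed, which is what allows
# the decoding step to be compiled with fullgraph=True.
if model._supports_static_cache:
    out = model.generate(**inputs, max_new_tokens=5, cache_implementation="static")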
- if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") config, inputs_dict = self.prepare_config_and_inputs_for_generate() @@ -1081,10 +1081,8 @@ def test_contrastive_generate_low_memory(self): if model_class._is_stateful: self.skipTest(reason="Stateful models don't support contrastive search generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") - if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): - self.skipTest(reason="TODO: fix me") config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) @@ -1123,22 +1121,16 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): for model_class in self.all_generative_model_classes: if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in [ - "bigbirdpegasus", - "led", - "mega", "moshi", - "speech2text", "git", "prophetnet", - "seamlessm4t", - "clvp", - "mllama", # special cache sizes - "blip2", # overridden `generate()` + "mllama", # special cache sizes + "blip2", # overridden `generate()` all BLIP models "instructblip", "instructblipvideo", ] @@ -1207,23 +1199,16 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): for model_class in self.all_generative_model_classes: if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in [ - "bigbirdpegasus", - "led", - "mega", "moshi", - "speech2text", "git", "prophetnet", - "seamlessm4t", - "clvp", - "fuyu", - "mllama", # special cache sizes - "blip2", # overridden `generate()` + "mllama", # special cache sizes + "blip2", # overridden `generate()` for all BLIP models "instructblip", "instructblipvideo", *VLM_CLASS_NAMES, # shouldn't suggest image tokens @@ -1335,22 +1320,16 @@ def test_assisted_decoding_sample(self): for model_class in self.all_generative_model_classes: if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): self.skipTest(reason="Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in [ - "bigbirdpegasus", - "led", - "mega", "moshi", - "speech2text", "git", "prophetnet", - "seamlessm4t", - "clvp", - "mllama", # special cache sizes - "blip2", # overridden `generate()` + "mllama", # special cache 
sizes + "blip2", # overridden `generate()` for all BLIP models "instructblip", "instructblipvideo", ] @@ -2468,10 +2447,7 @@ def _check_generate_outputs(self, output, config, use_cache=False, num_return_se # standard cache format (e.g.mamba architecture ) models_without_standard_cache = ( "bamba", - "ctrl", - "fsmt", "granitemoehybrid", - "mega", "reformer", "jamba", "mamba", From e9a281f02b5e8833b9d60e34ac595b12a708af2a Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 09:49:35 +0200 Subject: [PATCH 29/58] style --- tests/generation/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6bab4126bdb4..e8940be40c4c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1129,7 +1129,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "moshi", "git", "prophetnet", - "mllama", # special cache sizes + "mllama", # special cache sizes "blip2", # overridden `generate()` all BLIP models "instructblip", "instructblipvideo", @@ -1207,7 +1207,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): "moshi", "git", "prophetnet", - "mllama", # special cache sizes + "mllama", # special cache sizes "blip2", # overridden `generate()` for all BLIP models "instructblip", "instructblipvideo", @@ -1328,7 +1328,7 @@ def test_assisted_decoding_sample(self): "moshi", "git", "prophetnet", - "mllama", # special cache sizes + "mllama", # special cache sizes "blip2", # overridden `generate()` for all BLIP models "instructblip", "instructblipvideo", From e1a3fc4e2ae1d664b9b17709d27da8b00b6cfa3d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 1 Jul 2025 17:19:27 +0200 Subject: [PATCH 30/58] style --- src/transformers/models/align/modeling_align.py | 1 - src/transformers/models/bros/modeling_bros.py | 5 ++--- .../models/chinese_clip/modeling_chinese_clip.py | 10 +--------- src/transformers/models/esm/modeling_esm.py | 5 ++--- src/transformers/models/fsmt/modeling_fsmt.py | 2 +- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 9 ++++----- .../models/kosmos2/modeling_kosmos2.py | 7 ------- src/transformers/models/mpt/modeling_mpt.py | 2 +- .../models/reformer/modeling_reformer.py | 16 ++++++++-------- .../models/splinter/modeling_splinter.py | 3 +-- 10 files changed, 20 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 8ca671bc4a1a..a2152a7fc981 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -23,7 +23,6 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 2e265046fc99..c864ab078971 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -357,7 +357,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, - cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -399,13 +398,13 @@ def __init__(self, config): super().__init__() 
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BrosAttention(config, layer_idx) + self.attention = BrosAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise Exception(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BrosAttention(config, layer_idx) + self.crossattention = BrosAttention(config) self.intermediate = BrosIntermediate(config) self.output = BrosOutput(config) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 27f7638f170f..301e51965648 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -926,16 +926,8 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = ( - past_key_values[0][0].shape[-2] - if not isinstance(past_key_values, Cache) - else past_key_values.get_seq_length() - ) - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 53ecaa1b08fb..452cce33e506 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -23,7 +23,6 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -586,13 +585,13 @@ def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = EsmAttention(config, layer_idx) + self.attention = EsmAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = EsmAttention(config, layer_idx) + self.crossattention = EsmAttention(config) self.intermediate = EsmIntermediate(config) self.output = EsmOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 1554641c6ec6..8ed9feb35069 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -28,7 +28,7 @@ """PyTorch Fairseq model, ported from https://github.com/pytorch/fairseq/tree/master/examples/wmt19""" import math -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import Tensor, nn diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 
30d94a5650b0..75bf1426a11c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -14,7 +14,7 @@ """PyTorch GPTBigCode model.""" import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -24,9 +24,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...masking_utils import AttentionMaskConverter, create_causal_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available -from ...modeling_layers import GradientCheckpointingLayer +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -574,7 +573,7 @@ def forward( outputs = block( hidden_states, past_key_values, - attention_mask, + causal_mask, head_mask[i], encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index e49aad8a6321..18bd1e885470 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1489,13 +1489,6 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # Kosmos2 has offset for position ids, so we need to create them correctly - position_ids = create_position_ids_from_input_ids( - input_ids, - padding_idx=self.config.pad_token_id, - past_key_values_length=0, - ) - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore if cache_position[0] != 0: image_embeds = None diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index cadd4e5eac14..fc608fbe2eb7 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -310,7 +310,7 @@ def set_input_embeddings(self, new_embeddings: torch.Tensor): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[tuple[Tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None, + past_key_values: Optional[Union[tuple[tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 9eefad30a351..d0a80755b7d7 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from functools import reduce from operator import mul -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, Optional, Union import numpy as np import torch @@ -69,15 +69,15 @@ class ReformerDynamicCache(DynamicCache): def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: super().__init__() self._seen_tokens = 0 # Used in `generate` to keep tally of how many 
tokens the cache has seen - self.buckets_cache: List[torch.Tensor] = [] - self.states_cache: List[torch.Tensor] = [] + self.buckets_cache: list[torch.Tensor] = [] + self.states_cache: list[torch.Tensor] = [] if _distributed_cache_data is not None: for buckets, states in _distributed_cache_data: self.buckets_cache.append(buckets) self.states_cache.append(states) - def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. @@ -107,8 +107,8 @@ def update( buckets: torch.Tensor, states: torch.Tensor, layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -150,7 +150,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return None - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: """Converts the `ReformerDynamicCache` instance into the its equivalent in the legacy cache format. Used for backward compatibility.""" legacy_cache = () @@ -162,7 +162,7 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: @classmethod def from_legacy_cache( - cls, past_buckets_states: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None + cls, past_buckets_states: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None ) -> "ReformerDynamicCache": """Converts a cache in the legacy cache format into an equivalent `ReformerDynamicCache`. 
Used for backward compatibility.""" diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 38070bdb6ab7..20df061a90ac 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -23,9 +23,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput +from ...modeling_outputs import BaseModelOutput, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( From 3190a9e1e2060c30c960a871d659207bdfdc111c Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 2 Jul 2025 11:45:30 +0200 Subject: [PATCH 31/58] fix copies --- .../models/arcee/modeling_arcee.py | 3 +- .../modeling_bert_generation.py | 35 ++++++------------- .../models/dots1/modeling_dots1.py | 3 +- .../models/gemma3n/modeling_gemma3n.py | 3 +- .../models/glm4v/modeling_glm4v.py | 2 +- .../modeling_kyutai_speech_to_text.py | 2 +- .../models/smollm3/modeling_smollm3.py | 3 +- .../models/t5gemma/modeling_t5gemma.py | 3 +- 8 files changed, 18 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index da233918123c..71570ba9d3c2 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -325,8 +325,7 @@ class ArceePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index aa591ed652e1..9b7219f38aef 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -362,7 +362,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->BertGeneration -class BertEncoder(nn.Module): +class BertGenerationEncoder(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() self.config = config @@ -411,29 +411,16 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - cache_position, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_values, 
+ output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 58b805cca613..94efc5894b22 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -428,8 +428,7 @@ class Dots1PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3a4995610d4b..b338817dfac4 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1489,8 +1489,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 4148dfd10ac7..e038d1a196c1 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -385,7 +385,7 @@ class Glm4vPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 5abc0bd3fc01..803b8bc117a4 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -123,7 +123,7 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True + main_input_name = "input_ids" def _init_weights(self, module): diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 30b566be3e6a..f543ca78dfd9 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -230,8 +230,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index feccf6d7d9fd..01ad061fcb91 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -565,8 +565,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True From 
2f942f86151c8dd970a45195b2ffdd71d8d2c012 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 7 Jul 2025 11:38:52 +0200 Subject: [PATCH 32/58] fix tests --- .../models/bert_generation/modeling_bert_generation.py | 3 +-- src/transformers/models/bros/modeling_bros.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 9b7219f38aef..df2bebe11672 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -361,8 +361,7 @@ def feed_forward_chunk(self, attention_output): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->BertGeneration -class BertGenerationEncoder(nn.Module): +class BertEncoder(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() self.config = config diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index c864ab078971..0677038e7d4b 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -476,7 +476,7 @@ class BrosEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([BrosLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)]) @deprecate_kwarg("past_key_values", version="4.54.0") @deprecate_kwarg("use_cache", version="4.54.0") diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 452cce33e506..e46d5207e7b2 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -660,7 +660,7 @@ class EsmEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([EsmLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False From ccdd784f9487eb65af8837e22c7b4be992bc388f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 7 Jul 2025 11:54:59 +0200 Subject: [PATCH 33/58] fix copies --- src/transformers/models/got_ocr2/modeling_got_ocr2.py | 2 ++ src/transformers/models/moonshine/modeling_moonshine.py | 8 ++++++++ src/transformers/models/smollm3/modeling_smollm3.py | 1 + 3 files changed, 11 insertions(+) diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 89288784ffe0..358064d5d0df 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -279,8 +279,10 @@ class GotOcr2PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True _supports_sdpa = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 4ed7ef64a8a3..21ae3c757b0f 100644 --- 
a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -926,6 +926,9 @@ def forward( `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (): + + Example: ```python @@ -1043,6 +1046,11 @@ def forward( `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of type `torch.FloatTensor`. + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) Example: diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index e2e033993f6b..27da17934b22 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -292,6 +292,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { From 63f1bd3f4c9973a8d845f36d1ef76f293ff671c2 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 7 Jul 2025 12:01:38 +0200 Subject: [PATCH 34/58] create causal mask now needs positions? --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 75bf1426a11c..2592f1948f0a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -521,6 +521,7 @@ def forward( input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, + position_ids=position_ids, past_key_values=past_key_values, ) From 2dc3b01a1745096e7d744c85317aa696f43e41a9 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 7 Jul 2025 12:57:34 +0200 Subject: [PATCH 35/58] fixc copies --- src/transformers/models/moonshine/modeling_moonshine.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 21ae3c757b0f..4ed7ef64a8a3 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -926,9 +926,6 @@ def forward( `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of type `torch.FloatTensor`. - decoder_position_ids (): - - Example: ```python @@ -1046,11 +1043,6 @@ def forward( `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of type `torch.FloatTensor`. 
- decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): - Indices of positions of each input sequence tokens in the position embeddings. - Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. - - [What are position IDs?](../glossary#position-ids) Example: From e945d2f2f3a15344134eaccf00381ee28f9f5096 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 07:32:50 +0200 Subject: [PATCH 36/58] style --- src/transformers/cache_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 44330cc7cdff..0b7052c63485 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1428,19 +1428,6 @@ def __len__(self): """ return len(self.self_attention_cache) - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" legacy_cache = () From 8a7a05bda207ccde1b8f1e2f36e80326f23cdd85 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 10 Jul 2025 07:33:54 +0200 Subject: [PATCH 37/58] Update tests/test_modeling_common.py Co-authored-by: Joao Gante --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 12d205eafff1..c3afe32eb2bf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -853,7 +853,7 @@ def test_can_init_all_missing_weights(self): addition_year = int(match_object.group(1)) for model_class in self.all_model_classes: - # For now, skip everything older than 2025 and "important models" (too much models to patch otherwise) + # For now, skip everything older than 2024 and "important models" (too much models to patch otherwise) # Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them # TODO: relax this as we patch more and more models if addition_year < 2024: From ce665c0d38e17f3f0cb1a0b6c54bd151c9249f23 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 07:53:20 +0200 Subject: [PATCH 38/58] clean-up of non-generative model after merging main --- src/transformers/models/align/modeling_align.py | 2 +- src/transformers/models/altclip/modeling_altclip.py | 12 ++++-------- .../models/chinese_clip/modeling_chinese_clip.py | 2 +- src/transformers/models/clap/modeling_clap.py | 2 +- .../models/deepseek_v2/modeling_deepseek_v2.py | 3 +-- .../models/layoutlm/modeling_layoutlm.py | 2 +- .../models/markuplm/modeling_markuplm.py | 2 +- .../models/splinter/modeling_splinter.py | 2 +- 8 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index a2152a7fc981..da015bf7dd29 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -807,7 +807,7 @@ def feed_forward_chunk(self, attention_output): class AlignTextEncoder(nn.Module): - def __init__(self, config, 
layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 7b3135faeba0..30c18b506112 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -181,7 +181,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): class AltRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -293,13 +293,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class AltRobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, - position_embedding_type=position_embedding_type, - layer_idx=layer_idx, - ) + self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](config, position_embedding_type=position_embedding_type) self.output = AltRobertaSelfOutput(config) self.pruned_heads = set() @@ -425,7 +421,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta class AltRobertaEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 301e51965648..afe7bdb06a3b 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -661,7 +661,7 @@ def _init_weights(self, module): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP class ChineseCLIPTextEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index f32ef65e023e..707c04d0586a 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1326,7 +1326,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap class ClapTextEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 47862f30c7ef..a953f1472702 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ 
b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -459,8 +459,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 426e1a2f90bb..6fd8fcc80781 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -360,7 +360,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM class LayoutLMEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 93d8c11a7bdd..9fb5a7469bcc 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -521,7 +521,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM class MarkupLMEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)]) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 20df061a90ac..d11a1eb60d87 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -332,7 +332,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter class SplinterEncoder(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)]) From d0f68d0ba78298a4d56046ffbcf061fdba1f1d74 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 08:54:01 +0200 Subject: [PATCH 39/58] check `is_decoder` for cache --- src/transformers/models/bert/modeling_bert.py | 4 ++-- src/transformers/models/data2vec/modeling_data2vec_text.py | 2 +- src/transformers/models/doge/modeling_doge.py | 2 -- src/transformers/models/electra/modeling_electra.py | 2 +- src/transformers/models/ernie/modeling_ernie.py | 2 +- src/transformers/models/roberta/modeling_roberta.py | 5 ++--- .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 2 +- src/transformers/models/roc_bert/modeling_roc_bert.py | 2 +- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 3 +-- 9 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 3e1a501b9086..b1b2ec13b1f3 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -345,7 +345,7 @@ def forward( if self.position_embedding_type != "absolute" or 
output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. This warning can be removed using the argument " @@ -647,7 +647,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7218305d7374..759374865522 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -495,7 +495,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 8aaf54648150..35bb53109b73 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -495,8 +495,6 @@ class DogePreTrainedModel(PreTrainedModel): _supports_flash_attn_3 = False _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 5d45fd005236..9cc68dde875d 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -554,7 +554,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 424418e377f1..4abac47a726e 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -479,7 +479,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 11fca1540da2..7f0b787310fe 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -295,7 +295,7 @@ def forward( if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. This warning can be removed using the argument " @@ -320,7 +320,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) @@ -604,7 +603,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 15c2c79b7d90..c6ca6803d6ab 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -487,7 +487,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index b9d89d6fb4ba..40a59350427c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -606,7 +606,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index c6f931110ac0..5c0971c6b2d8 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -293,7 +293,7 @@ def forward( if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. This warning can be removed using the argument " @@ -318,7 +318,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) From 72bc51a6c2eab72fa0fce9cb7e74974d384e89f8 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 09:31:40 +0200 Subject: [PATCH 40/58] delete transpose for scores --- .../models/albert/modeling_albert.py | 20 +++------ .../modeling_audio_spectrogram_transformer.py | 29 ++++++++---- src/transformers/models/beit/modeling_beit.py | 21 +++------ src/transformers/models/bert/modeling_bert.py | 17 +++---- .../modeling_bert_generation.py | 23 +++++----- .../models/big_bird/modeling_big_bird.py | 24 +++------- .../modeling_bigbird_pegasus.py | 44 ++++++++++-------- .../models/blip/modeling_blip_text.py | 13 ++---- .../bridgetower/modeling_bridgetower.py | 25 ++++++----- src/transformers/models/bros/modeling_bros.py | 7 ++- .../models/camembert/modeling_camembert.py | 28 ++++++------ .../models/canine/modeling_canine.py | 14 ++---- .../models/convbert/modeling_convbert.py | 14 ++---- .../models/data2vec/modeling_data2vec_text.py | 23 +++++----- .../data2vec/modeling_data2vec_vision.py | 45 +++++++++++++------ src/transformers/models/deit/modeling_deit.py | 29 ++++++++---- .../models/dinat/modeling_dinat.py | 12 ++--- .../models/dinov2/modeling_dinov2.py | 29 ++++++++---- .../modeling_dinov2_with_registers.py | 29 ++++++++---- src/transformers/models/dpt/modeling_dpt.py | 29 ++++++++---- .../models/electra/modeling_electra.py | 23 +++++----- .../models/ernie/modeling_ernie.py | 23 +++++----- 
.../models/flava/modeling_flava.py | 14 ++---- src/transformers/models/git/modeling_git.py | 14 ++---- src/transformers/models/glpn/modeling_glpn.py | 24 ++++++---- .../grounding_dino/modeling_grounding_dino.py | 12 ++--- .../models/ibert/modeling_ibert.py | 11 ++--- .../models/ijepa/modeling_ijepa.py | 29 ++++++++---- .../models/layoutlmv2/modeling_layoutlmv2.py | 14 +++--- .../models/layoutlmv3/modeling_layoutlmv3.py | 14 ++---- .../models/lxmert/modeling_lxmert.py | 19 ++------ .../megatron_bert/modeling_megatron_bert.py | 23 +++++----- .../models/mobilebert/modeling_mobilebert.py | 16 ++----- .../models/mobilevit/modeling_mobilevit.py | 14 ++---- .../models/mpnet/modeling_mpnet.py | 16 ++----- src/transformers/models/mra/modeling_mra.py | 28 +++++------- .../nystromformer/modeling_nystromformer.py | 14 ++---- .../omdet_turbo/modeling_omdet_turbo.py | 12 ++--- .../models/rembert/modeling_rembert.py | 14 ++---- .../models/roberta/modeling_roberta.py | 23 +++++----- .../modeling_roberta_prelayernorm.py | 23 +++++----- .../models/roc_bert/modeling_roc_bert.py | 23 +++++----- .../models/roformer/modeling_roformer.py | 13 ++---- .../models/segformer/modeling_segformer.py | 12 ++--- .../models/superglue/modeling_superglue.py | 12 ++--- .../models/swin2sr/modeling_swin2sr.py | 25 ++++++----- .../models/swinv2/modeling_swinv2.py | 13 ++---- .../models/tapas/modeling_tapas.py | 13 ++---- .../models/videomae/modeling_videomae.py | 12 ++--- src/transformers/models/vilt/modeling_vilt.py | 14 ++---- .../visual_bert/modeling_visual_bert.py | 28 +++--------- src/transformers/models/vit/modeling_vit.py | 17 ++++--- .../models/vit_mae/modeling_vit_mae.py | 29 ++++++++---- .../models/vit_msn/modeling_vit_msn.py | 29 ++++++++---- .../modeling_vitpose_backbone.py | 29 ++++++++---- .../models/vivit/modeling_vivit.py | 29 ++++++++---- .../models/vjepa2/modeling_vjepa2.py | 17 ++----- .../xlm_roberta/modeling_xlm_roberta.py | 28 ++++++------ .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 23 +++++----- .../models/xlnet/modeling_xlnet.py | 4 +- src/transformers/models/xmod/modeling_xmod.py | 23 +++++----- .../models/yolos/modeling_yolos.py | 29 ++++++++---- src/transformers/models/yoso/modeling_yoso.py | 14 ++---- .../models/zoedepth/modeling_zoedepth.py | 20 +++++---- 64 files changed, 651 insertions(+), 660 deletions(-) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 7285c8ba569a..6482b9db72f3 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -271,12 +271,6 @@ def __init__(self, config: AlbertConfig): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def prune_heads(self, heads: list[int]) -> None: if len(heads) == 0: return @@ -302,13 +296,13 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - 
query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index a6baf638143b..7395a0f4e166 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -149,17 +149,28 @@ def __init__(self, config: ASTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 2b5fb795eaea..c9d0f47eca9c 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -260,11 +260,6 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> N if self.has_relative_position_bias: self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -274,11 +269,10 @@ def 
forward( interpolate_pos_encoding: bool = False, resolution: Optional[tuple[int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -345,10 +339,9 @@ def forward( resolution=resolution, ) - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) attn_bias = None if self.has_relative_position_bias: diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index b1b2ec13b1f3..8cc99d6980c9 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -217,11 +217,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -232,7 +227,9 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -256,8 +253,10 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) 
if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -269,8 +268,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index df2bebe11672..02fbb04cd659 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -82,11 +82,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -97,7 +92,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -121,8 +120,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -134,8 +139,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 2b95ca52941b..94f50d174a8b 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -317,11 +317,6 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -333,7 +328,7 @@ def forward( output_attentions=False, cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # NOTE: BigBird has only cross attention layers so we can ignore self attn path current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states @@ -342,8 +337,8 @@ def forward( key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -353,8 +348,6 @@ def forward( self.layer_idx, ) - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -411,11 +404,6 @@ def __init__(self, config, seed=None): self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -438,9 +426,9 @@ def forward( if to_seq_length % to_block_size != 0: raise ValueError("Key/Value sided sequence length must be multiple of block size") - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) context_layer, attention_probs = self.bigbird_block_sparse_attention( query_layer, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 06ea62bf8765..496c3b469a10 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -128,11 +128,6 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -144,7 +139,11 @@ def forward( output_attentions=False, cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # NOTE: BigBirdPegasus has only cross attention layers so we can ignore self attn path current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states @@ -153,8 +152,12 @@ def forward( key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + value_layer = self.value(current_states).view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -164,8 +167,6 @@ def forward( self.layer_idx, ) - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -223,11 +224,6 @@ def __init__(self, config, seed=None): self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -250,9 +246,21 @@ def forward( if to_seq_length % to_block_size != 0: raise ValueError("Key/Value sided sequence length must be multiple of block size") - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) context_layer, attention_probs = self.bigbird_block_sparse_attention( query_layer, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49c168f6b0a5..ff70fdbc7026 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -138,11 +138,6 @@ def save_attention_map(self, attention_map): def get_attention_map(self): return self.attention_map - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -154,7 +149,7 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -179,8 +174,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -192,8 +187,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index eba0042df6f4..fc72dacb3402 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -429,11 +429,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -444,7 +439,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -468,8 +467,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -481,8 +486,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -774,7 +777,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 0677038e7d4b..dd8ecf2c9ebe 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -184,7 +184,7 @@ def forward( class BrosSelfAttention(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -207,7 +207,6 @@ def __init__(self, config, layer_idx=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder - self.layer_idx = layer_idx @deprecate_kwarg("past_key_value", version="4.54.0") def forward( @@ -310,9 +309,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BrosAttention(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() - self.self = BrosSelfAttention(config, layer_idx=layer_idx) + self.self = BrosSelfAttention(config) self.output = BrosSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index d7698dc3bdf1..3ddad9b6b90d 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -167,11 +167,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -182,7 +177,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -206,8 +205,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -219,8 +224,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -296,7 +299,7 @@ def forward( if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMCamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. This warning can be removed using the argument " @@ -321,7 +324,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) @@ -605,7 +607,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index e19fe2534e1c..29e544e7c5be 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -410,11 +410,6 @@ def __init__(self, config): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, from_tensor: torch.Tensor, @@ -423,16 +418,15 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - mixed_query_layer = self.query(from_tensor) + batch_size, seq_length, _ = from_tensor.shape # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - key_layer = self.transpose_for_scores(self.key(to_tensor)) - value_layer = self.transpose_for_scores(self.value(to_tensor)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.key(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = self.query(from_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index bdac1fecc1c3..95c082b0c28a 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -325,11 +325,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -338,8 +333,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - batch_size = hidden_states.size(0) + batch_size, seq_length, _ = hidden_states.shape # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. @@ -353,9 +347,9 @@ def forward( mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2)) mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2) - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + mixed_query_layer = self.query(hidden_states) + query_layer = mixed_query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = mixed_value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer) conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 759374865522..e37b7ec756c4 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -167,11 +167,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -182,7 +177,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -206,8 +205,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer =
self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -219,8 +224,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index ede5404571ab..3f5249370786 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -261,11 +261,6 @@ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = if self.has_relative_position_bias: self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -275,11 +270,22 @@ def forward( interpolate_pos_encoding: bool = False, resolution: Optional[tuple[int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -347,10 +353,21 @@ def forward( resolution=resolution, ) - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attn_bias = None if self.has_relative_position_bias: diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 90b678466c49..573029bb119c 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -214,17 +214,28 @@ def __init__(self, config: DeiTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 0618e0d2ded8..63a438de0b33 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -270,19 +270,15 @@ def __init__(self, config, dim, num_heads, kernel_size, dilation): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 3, 1, 2, 4) - def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + 
batch_size, height, width, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, height, width, self.num_attention_heads, self.attention_head_size).permute(0, 3, 1, 2, 4) + key_layer = self.key(hidden_states).view(batch_size, height, width, self.num_attention_heads, self.attention_head_size).permute(0, 3, 1, 2, 4) + value_layer = self.value(hidden_states).view(batch_size, height, width, self.num_attention_heads, self.attention_head_size).permute(0, 3, 1, 2, 4) # Apply the scale factor before computing attention weights. It's usually more efficient because # attention weights are typically a bigger tensor compared to query. diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6b98e3fa8d65..072226326213 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -202,17 +202,28 @@ def __init__(self, config: Dinov2Config) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index 7d37f00daa91..fdebf7d99031 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -223,17 +223,28 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer =
self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 4d0a072dcc2d..bf91458df2cc 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -326,17 +326,28 @@ def __init__(self, config: DPTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 9cc68dde875d..ef38d43a3e01 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -228,11 +228,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -243,7 +238,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = 
query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -267,8 +266,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -280,8 +285,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 4abac47a726e..287eea8c7016 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -153,11 +153,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -168,7 +163,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -192,8 +191,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -205,8 +210,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index a1a30f369a10..6168873c5980 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -442,11 +442,6 @@ def __init__(self, config: FlavaPossibleConfigs) -> None: self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -454,11 +449,10 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index ed05d9302ce1..6f4c29a3d2a4 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -151,11 +151,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -165,11 +160,12 @@ def forward( output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) cutoff = self.image_patch_tokens if pixel_values_present else 0 - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component. 
key_layer_past, value_layer_past = past_key_value.update( @@ -178,8 +174,6 @@ def forward( key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2) value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2) - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 8715a09613a3..b21d2f14d765 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -126,11 +126,6 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ ) self.layer_norm = nn.LayerNorm(hidden_size) - def transpose_for_scores(self, hidden_states): - new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - hidden_states = hidden_states.view(new_shape) - return hidden_states.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -138,7 +133,12 @@ def forward( width, output_attentions=False, ): - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if self.sr_ratio > 1: batch_size, seq_len, num_channels = hidden_states.shape @@ -150,8 +150,16 @@ def forward( hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 743f74a1215b..db06bca8e0b2 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -1183,11 +1183,6 @@ def __init__(self, config, num_attention_heads=None): self.dropout = nn.Dropout(config.attention_dropout) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, queries: torch.Tensor, @@ -1196,9 +1191,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(queries)) - key_layer = self.transpose_for_scores(self.key(keys)) - value_layer = self.transpose_for_scores(self.value(values)) + batch_size, seq_length, _ = queries.shape + query_layer = self.query(queries).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 89fd716f885f..064aec235484 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -226,11 +226,6 @@ def __init__(self, config): self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -254,9 +249,10 @@ def forward( ) # Transpose + batch_size, seq_length, _ = hidden_states.shape - query_layer = self.transpose_for_scores(query_layer) - key_layer = self.transpose_for_scores(key_layer) - value_layer = self.transpose_for_scores(value_layer) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 6c2df8d2bbe7..f76067d6a6ab 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -230,17 +230,28 @@ def __init__(self, config: IJepaConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 66637bedd8d2..7f6a861a674d 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -123,11 +123,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def compute_qkv(self, hidden_states): if self.fast_qkv: qkv = self.qkv_linear(hidden_states) @@ -154,12 +149,13 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - q, k, v = self.compute_qkv(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query, key, value = self.compute_qkv(hidden_states) # (B, L, H*D) -> (B, H, L, D) - query_layer = self.transpose_for_scores(q) - key_layer = self.transpose_for_scores(k) - value_layer = self.transpose_for_scores(v) + query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) query_layer = query_layer / math.sqrt(self.attention_head_size) # [BSZ, NAT, L, L] diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 05f662b12a9f..04b40897678f 100644 --- 
a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -245,11 +245,6 @@ def __init__(self, config): self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def cogview_attention(self, attention_scores, alpha=32): """ https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation @@ -271,11 +266,10 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index bc8a84d17e01..897e1e5624cb 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -319,22 +319,11 @@ def __init__(self, config, ctx_dim=None): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(context) - mixed_value_layer = self.value(context) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index b96a023f50cd..45153d5ddbc6 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -206,11 +206,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -221,7 +216,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -245,8 +244,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -258,8 +263,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index b1c267c959bc..0ce8b0e9df6e 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -231,11 +231,6 @@ def __init__(self, config): ) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, query_tensor: torch.Tensor, @@ -245,13 +240,10 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(query_tensor) - mixed_key_layer = self.key(key_tensor) - mixed_value_layer = self.value(value_tensor) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + batch_size, seq_length, _ = query_tensor.shape + query_layer = self.query(query_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(key_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(value_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 1b483fe958c0..60138fd892bb 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -211,17 +211,11 @@ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index efa74f191f13..145516d7a141 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -144,11 +144,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -158,13 +153,10 @@ def forward( output_attentions=False, **kwargs, ): - q = self.q(hidden_states) - k = self.k(hidden_states) - v = self.v(hidden_states) - - q = self.transpose_for_scores(q) - k = self.transpose_for_scores(k) - v = self.transpose_for_scores(v) + batch_size, seq_length, _ = hidden_states.shape + q = self.q(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + k = self.k(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + v = self.v(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(q, k.transpose(-1, -2)) diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index a7fd783d848d..ea22bb196a0b 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -555,32 +555,24 @@ def __init__(self, config, position_embedding_type=None): self.initial_prior_first_n_blocks = config.initial_prior_first_n_blocks self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks - def transpose_for_scores(self, layer): - new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - layer = layer.view(*new_layer_shape) - return layer.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - batch_size, num_heads, seq_len, head_dim = query_layer.size() + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # revert changes made by get_extended_attention_mask attention_mask = 1.0 + attention_mask / 10000.0 attention_mask = ( - attention_mask.squeeze().repeat(1, num_heads, 1).reshape(batch_size * num_heads, seq_len).int() + attention_mask.squeeze().repeat(1, self.num_attention_heads, 1).reshape(batch_size * self.num_attention_heads, seq_length).int() ) # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs # smaller than this are padded with zeros.
gpu_warp_size = 32 - if head_dim < gpu_warp_size: - pad_size = batch_size, num_heads, seq_len, gpu_warp_size - head_dim + if self.attention_head_size < gpu_warp_size: + pad_size = batch_size, self.num_attention_heads, seq_length, gpu_warp_size - self.attention_head_size query_layer = torch.cat([query_layer, torch.zeros(pad_size, device=query_layer.device)], dim=-1) key_layer = torch.cat([key_layer, torch.zeros(pad_size, device=key_layer.device)], dim=-1) @@ -597,10 +589,10 @@ def forward(self, hidden_states, attention_mask=None): initial_prior_diagonal_n_blocks=self.initial_prior_diagonal_n_blocks, ) - if head_dim < gpu_warp_size: - context_layer = context_layer[:, :, :, :head_dim] + if self.attention_head_size < gpu_warp_size: + context_layer = context_layer[:, :, :, :self.attention_head_size] - context_layer = context_layer.reshape(batch_size, num_heads, seq_len, head_dim) + context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_length, self.attention_head_size) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index f5b940157ded..cada50127fc8 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -167,17 +167,11 @@ def iterative_inv(self, mat, n_iter=6): ) return value - def transpose_for_scores(self, layer): - new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - layer = layer.view(*new_layer_shape) - return layer.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size)) key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size)) diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 9bac40553d9f..77e16e419e23 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -494,11 +494,6 @@ def __init__(self, config, hidden_size, num_attention_heads, dropout): self.out_proj = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(dropout) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, queries: torch.Tensor, @@ -507,9 +502,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: -
query_layer = self.transpose_for_scores(self.query(queries)) - key_layer = self.transpose_for_scores(self.key(keys)) - value_layer = self.transpose_for_scores(self.value(values)) + batch_size, seq_length, _ = queries.shape + query_layer = self.query(queries).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 07b06bd489a0..15eec9d29396 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -221,11 +221,6 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -236,7 +231,8 @@ def forward( output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -260,8 +256,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -273,8 +269,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 7f0b787310fe..29c28557d184 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -166,11 +166,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -181,7 +176,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -205,8 +204,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -218,8 +223,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index c6ca6803d6ab..604c2d91b5f7 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -165,11 +165,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -180,7 +175,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -204,8 +203,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -217,8 +222,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 40a59350427c..487ee3c28754 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -280,11 +280,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -295,7 +290,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -319,8 +318,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -332,8 +337,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index bf20f6c3bc5c..c35d9f76a035 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -211,11 +211,6 @@ def __init__(self, config, layer_idx=None): self.rotary_value = config.rotary_value self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -227,8 +222,8 @@ def forward( output_attentions=False, cache_position=None, ): - mixed_query_layer = self.query(hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. @@ -251,8 +246,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 81c220446103..70cd1c67aa1c 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -152,11 +152,6 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ ) self.layer_norm = nn.LayerNorm(hidden_size) - def transpose_for_scores(self, hidden_states): - new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - hidden_states = hidden_states.view(new_shape) - return hidden_states.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -164,7 +159,8 @@ def forward( width, output_attentions=False, ): - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if self.sr_ratio > 1: batch_size, seq_len, num_channels = hidden_states.shape @@ -176,8 +172,8 @@ def forward( hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, 
self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index beae219fa66d..f12b12d58fe9 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -260,11 +260,6 @@ def __init__(self, config, position_embedding_type=None): self.is_decoder = config.is_decoder - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -281,9 +276,10 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size = hidden_states.shape[0] + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 4b16bc954dc4..de61e5b2d259 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -296,11 +296,6 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -309,11 +304,21 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # cosine attention attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index d18c126fe4b2..9bfbfb047073 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -467,11 +467,6 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -480,11 +475,9 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # cosine attention attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 
221bb7544260..ab6a54e2f815 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -300,11 +300,6 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -315,7 +310,8 @@ def forward( output_attentions=False, cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -339,8 +334,8 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -352,8 +347,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 74293b39716b..d8960f54de94 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -239,22 +239,18 @@ def __init__(self, config: VideoMAEConfig) -> None: self.q_bias = None self.v_bias = None - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + batch_size, seq_length, _ = hidden_states.shape k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias) values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias) queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias) - key_layer = self.transpose_for_scores(keys) - value_layer = self.transpose_for_scores(values) - query_layer = self.transpose_for_scores(queries) + key_layer = keys.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index f54cc65822d8..014ec1b4b410 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -328,17 +328,11 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 305cc68a39ec..7369bab5a9a9 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -193,11 +193,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -205,12 +200,10 @@ def forward( head_mask=None, output_attentions=False, ): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -1367,21 +1360,14 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward(self, query, key, attention_mask): + batch_size, seq_length, _ = query.shape attention_mask = attention_mask.to(query.dtype) attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min - mixed_query_layer = self.query(query) - mixed_key_layer = self.key(key) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) + query_layer = self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(key).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index f898b8138264..719d803353cb 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -218,17 +218,16 @@ def __init__(self, config: ViTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: 
Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 8f0376863934..9792ee253cd9 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -378,17 +378,28 @@ def __init__(self, config: ViTMAEConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index c8b1b6f6cf67..88670bfef070 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -216,17 +216,28 @@ def __init__(self, config: ViTMSNConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: 
Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index dad8dfe9c4f7..ed9c1d88de51 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -148,17 +148,28 @@ def __init__(self, config: VitPoseBackboneConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 8a2f3177604e..fefa6e9569d5 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -209,17 +209,28 @@ def __init__(self, config: VivitConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: 
Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index efbf452d1fc1..0b940cc45b01 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -243,14 +243,6 @@ def __init__( self.scaling = self.attention_head_size**-0.5 self.is_causal = False - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def _get_frame_pos(self, ids): tokens_per_frame = int(self.grid_size * self.grid_size) return ids // tokens_per_frame @@ -309,11 +301,10 @@ def forward( output_attentions: bool = False, head_mask: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) pos_ids = self.get_position_ids(hidden_states, masks=position_mask) key_layer = self.apply_rotary_embeddings(key_layer, pos_ids) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 84e36f72f1d3..8aa928d52038 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -167,11 +167,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -182,7 +177,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: 
Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -206,8 +205,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -219,8 +224,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -296,7 +299,7 @@ def forward( if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. logger.warning_once( - "XLMXLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " "the manual attention implementation, but specifying the manual implementation will be required from " "Transformers version v5.0.0 onwards. This warning can be removed using the argument " @@ -321,7 +324,6 @@ def forward( is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) @@ -605,7 +607,7 @@ def forward( use_cache = False return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 5c0971c6b2d8..41ecdd83994c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -164,11 +164,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -179,7 +174,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -203,8 +202,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -216,8 +221,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 8cad1124a3a8..315d35bbd7f2 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -187,7 +187,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path): class XLNetRelativeAttention(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config): super().__init__() if config.d_model % config.n_head != 0: @@ -214,7 +214,6 @@ def __init__(self, config, layer_idx=None): self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.dropout) - self.layer_idx = layer_idx def prune_heads(self, heads): raise NotImplementedError @@ -323,7 +322,6 @@ def forward( target_mapping=None, head_mask=None, output_attentions=False, - cache_position=None, ): if g is not None: # Two-stream attention with relative positional encoding. 
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index a8fe528a764f..88567b6e94a3 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -164,11 +164,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -179,7 +174,11 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -203,8 +202,14 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -216,8 +221,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index bb9f034e81dd..8e395a6f9ef3 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -264,17 +264,28 @@ def __init__(self, config: YolosConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index da35490c59ed..913386a05795 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -341,17 +341,11 @@ def __init__(self, config, position_embedding_type=None): groups=config.num_attention_heads, ) - def transpose_for_scores(self, layer): - new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - layer = layer.view(*new_layer_shape) - return layer.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) if self.use_conv: conv_value_layer = self.conv(value_layer * attention_mask[:, None, :, None]) diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 48ff8174186e..a0bbaf05979e 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ 
b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -799,11 +799,6 @@ def __init__(self, hidden_size, num_attention_heads, dropout): self.dropout = nn.Dropout(dropout) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, queries: torch.Tensor, @@ -812,9 +807,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(queries)) - key_layer = self.transpose_for_scores(self.key(keys)) - value_layer = self.transpose_for_scores(self.value(values)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) From ac739595d62001ce8aa922886e3312733c944341 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 09:31:58 +0200 Subject: [PATCH 41/58] remove tuple cache from docs everywhere --- docs/source/en/cache_explanation.md | 2 +- src/transformers/utils/args_doc.py | 17 +++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 6c31035234bb..13f310669200 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -141,7 +141,7 @@ The legacy format is essentially the same data structure but organized different - The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. - The format is less flexible and doesn't support features like quantization or offloading. -If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. +If your project depends on this legacy format, we recommend converting to [`DynamicCache`] with [`~DynamicCache.from_legacy_cache`]. Note that the legacy cache format is deprecated and no longer used in `Transformers`. You can convert back to the tuple format with the [`DynamicCache.to_legacy_cache`] function, which is helpful if you have custom logic for manipulating a cache in a specific format. ```py import torch diff --git a/src/transformers/utils/args_doc.py b/src/transformers/utils/args_doc.py index 98d1473dcf90..e437630d4832 100644 --- a/src/transformers/utils/args_doc.py +++ b/src/transformers/utils/args_doc.py @@ -352,17 +352,13 @@ class ModelArgs: blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
- Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. + Only [`~cache_utils.Cache`] instance is allowed as input, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + If no `past_key_values` are passed, [`~cache_utils.DynamicCache`] will be initialized by default. - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. + The model will output the same cache format that is fed as input. - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. """, "shape": None, @@ -939,9 +935,6 @@ class ClassAttrs: _supports_flex_attn = r""" Whether the model's attention implementation supports FlexAttention. """ - _supports_quantized_cache = r""" - Whether the model supports a `QuantoQuantizedCache` instance as `past_key_values`. - """ _supports_static_cache = r""" Whether the model supports a `StaticCache` instance as `past_key_values`. """ From fd84e6794130032d87271db900521463e4b51587 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 09:58:50 +0200 Subject: [PATCH 42/58] fix tests --- .../models/albert/modeling_albert.py | 26 ++++++++++--- .../models/altclip/modeling_altclip.py | 4 +- src/transformers/models/beit/modeling_beit.py | 37 ++++++++++++++++--- src/transformers/models/bert/modeling_bert.py | 28 +++++++++++--- .../models/big_bird/modeling_big_bird.py | 33 ++++++++++++++--- .../modeling_bigbird_pegasus.py | 1 + .../models/blip/modeling_blip_text.py | 19 ++++++++-- .../models/canine/modeling_canine.py | 20 ++++++++-- .../models/convbert/modeling_convbert.py | 16 ++++++-- .../data2vec/modeling_data2vec_vision.py | 1 + .../models/dinat/modeling_dinat.py | 18 +++++++-- .../models/flava/modeling_flava.py | 18 +++++++-- src/transformers/models/git/modeling_git.py | 18 +++++++-- .../grounding_dino/modeling_grounding_dino.py | 16 ++++++-- .../models/ibert/modeling_ibert.py | 9 ++++- .../models/layoutlmv3/modeling_layoutlmv3.py | 18 +++++++-- .../models/lxmert/modeling_lxmert.py | 16 ++++++-- .../models/mobilebert/modeling_mobilebert.py | 20 ++++++++-- .../models/mobilevit/modeling_mobilevit.py | 18 +++++++-- .../models/mpnet/modeling_mpnet.py | 18 +++++++-- src/transformers/models/mra/modeling_mra.py | 27 +++++++++++--- .../nystromformer/modeling_nystromformer.py | 18 +++++++-- .../omdet_turbo/modeling_omdet_turbo.py | 16 ++++++-- .../models/rembert/modeling_rembert.py | 18 +++++++-- .../models/roberta/modeling_roberta.py | 16 ++++++-- .../models/roformer/modeling_roformer.py | 18 +++++++-- .../models/segformer/modeling_segformer.py | 18 +++++++-- .../models/superglue/modeling_superglue.py | 18 +++++++-- .../models/swinv2/modeling_swinv2.py | 18 +++++++-- .../models/tapas/modeling_tapas.py | 19 ++++++++-- 
src/transformers/models/vilt/modeling_vilt.py | 18 +++++++-- .../visual_bert/modeling_visual_bert.py | 26 ++++++++++--- src/transformers/models/vit/modeling_vit.py | 18 +++++++-- .../models/vjepa2/modeling_vjepa2.py | 18 +++++++-- .../xlm_roberta/modeling_xlm_roberta.py | 16 ++++++-- src/transformers/models/yoso/modeling_yoso.py | 18 +++++++-- .../models/zoedepth/modeling_zoedepth.py | 2 +- 37 files changed, 535 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 6482b9db72f3..005665d324b9 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -300,9 +300,13 @@ def forward( query_layer = self.query(hidden_states) key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -372,9 +376,21 @@ def forward( return super().forward(hidden_states, attention_mask, output_attentions=output_attentions) batch_size, seq_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 30c18b506112..c770dd5adcea 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -295,7 +295,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class AltRobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](config, position_embedding_type=position_embedding_type) + self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = AltRobertaSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index c9d0f47eca9c..9c964467e513 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -270,9 +270,21 @@ def forward( resolution: Optional[tuple[int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -339,9 +351,22 @@ def forward( resolution=resolution, ) - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attn_bias = None if self.has_relative_position_bias: diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 8cc99d6980c9..ec56b8469e41 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -229,7 +229,9 @@ def forward( ) -> tuple[torch.Tensor]: batch_size, seq_length, _ = hidden_states.shape query_layer = self.query(hidden_states) - query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -254,9 +256,13 @@ def forward( value_layer = curr_past_key_value.value_cache[self.layer_idx] else: key_layer = self.key(current_states) - key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) value_layer = self.value(current_states) - value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -360,7 +366,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
@@ -384,8 +392,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 94f50d174a8b..0986e050a325 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -328,7 +328,12 @@ def forward( output_attentions=False, cache_position=None, ): - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # NOTE: BigBird has only cross attention layers so we can ignore self attn path current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states @@ -337,8 +342,12 @@ def forward( key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states)(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states)(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = self.key(current_states).view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + value_layer = self.value(current_states).view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation @@ -426,9 +435,21 @@ def forward( if to_seq_length % to_block_size != 0: raise ValueError("Key/Value sided sequence length must be multiple of block size") - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) context_layer, attention_probs = self.bigbird_block_sparse_attention( query_layer, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 
496c3b469a10..905d876ff018 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -139,6 +139,7 @@ def forward( output_attentions=False, cache_position=None, ): + batch_size, seq_length, _ = hidden_states.shape query_layer = ( self.query(hidden_states) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index ff70fdbc7026..0123ff73e20c 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -149,7 +149,12 @@ def forward( output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -174,8 +179,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 29e544e7c5be..9866aad87a4e 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -418,15 +418,27 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - batch_size, seq_length, _ = hidden_states.shape + batch_size, seq_length, _ = from_tensor.shape # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
- key_layer = self.key(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - mixed_query_layer = self.query(from_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(from_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 95c082b0c28a..f6b554d0f12e 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -347,10 +347,18 @@ def forward( mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2)) mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2) - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = mixed_value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = mixed_value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, query_layer) conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) conv_kernel_layer = torch.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1]) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 3f5249370786..2cf64ac21f81 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -353,6 +353,7 @@ def forward( resolution=resolution, ) + batch_size, seq_length, _ = hidden_states.shape query_layer = ( self.query(hidden_states) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 63a438de0b33..140d16bd33b9 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -276,9 +276,21 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, 
self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Apply the scale factor before computing attention weights. It's usually more efficient because # attention weights are typically a bigger tensor compared to query. diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 6168873c5980..64a61e66b520 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -450,9 +450,21 @@ def forward( output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 6f4c29a3d2a4..c6b1aa1f88db 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -161,11 +161,23 @@ def forward( pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) cutoff = self.image_patch_tokens if pixel_values_present else 0 - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component. key_layer_past, value_layer_past = past_key_value.update( diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index db06bca8e0b2..197c99c57be6 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -1191,10 +1191,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(queries).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 064aec235484..1a58783e80d2 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -249,9 +249,14 @@ def forward( ) # Transpose - query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = hidden_states.shape + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 04b40897678f..8b5628541092 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -267,9 +267,21 @@ def forward( rel_2d_pos=None, ): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. 
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 897e1e5624cb..4138cb0b82a9 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -321,9 +321,19 @@ def __init__(self, config, ctx_dim=None): def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(context) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 0ce8b0e9df6e..91508d099711 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -240,10 +240,22 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, ) -> tuple[torch.Tensor]: - batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(query_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(key_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(value_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = query_tensor.shape + query_layer = ( + self.query(query_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(key_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(value_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 60138fd892bb..3f882b9850ff 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -213,9 +213,21 @@ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 145516d7a141..82698f8ecfb2 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -154,9 +154,21 @@ def forward( **kwargs, ): batch_size, seq_length, _ = hidden_states.shape - q = self.q(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - k = self.k(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - v = self.v(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + q = ( + self.q(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + k = ( + self.k(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + v = ( + self.v(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(q, k.transpose(-1, -2)) diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index ea22bb196a0b..159299aa3053 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -556,15 +556,30 @@ def __init__(self, config, position_embedding_type=None): self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks def forward(self, hidden_states, attention_mask=None): - batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_len, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # revert changes made by get_extended_attention_mask attention_mask = 1.0 + attention_mask / 10000.0 attention_mask = ( - attention_mask.squeeze().repeat(1, self.num_attention_heads, 1).reshape(batch_size * self.num_attention_heads, seq_len).int() + attention_mask.squeeze() + .repeat(1, self.num_attention_heads, 1) + .reshape(batch_size * self.num_attention_heads, seq_len) + .int() ) # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). 
Inputs @@ -590,7 +605,7 @@ def forward(self, hidden_states, attention_mask=None): ) if self.attention_head_size < gpu_warp_size: - context_layer = context_layer[:, :, :, :self.attention_head_size] + context_layer = context_layer[:, :, :, : self.attention_head_size] context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_len, self.attention_head_size) diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index cada50127fc8..babd8acc09f7 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -169,9 +169,21 @@ def iterative_inv(self, mat, n_iter=6): def forward(self, hidden_states, attention_mask=None, output_attentions=False): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size)) key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size)) diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 77e16e419e23..ab1ae0b9744b 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -502,10 +502,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(queries).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 15eec9d29396..565771b1efc4 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -232,7 +232,11 @@ def forward( cache_position: Optional[torch.Tensor] = None, ) -> tuple: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -256,8 +260,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 29c28557d184..ee1de37d0244 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -316,7 +316,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
@@ -340,8 +342,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index c35d9f76a035..8360071b62b2 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -223,7 +223,11 @@ def forward( cache_position=None, ): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. @@ -246,8 +250,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 70cd1c67aa1c..b998e0546d8f 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -160,7 +160,11 @@ def forward( output_attentions=False, ): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if self.sr_ratio > 1: batch_size, seq_len, num_channels = hidden_states.shape @@ -172,8 +176,16 @@ def forward( hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + 
key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index f12b12d58fe9..56506bc7a235 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -277,9 +277,21 @@ def forward( attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask batch_size = hidden_states.shape[0] - key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 9bfbfb047073..14ec4791ac8d 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -475,9 +475,21 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # cosine attention attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index ab6a54e2f815..6fb656ebdcf0 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -310,7 +310,12 @@ def forward( output_attentions=False, cache_position=None, ): - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be @@ -334,8 +339,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(current_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 014ec1b4b410..2600605fc604 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -330,9 +330,21 @@ def __init__(self, config): def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, 
-1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 7369bab5a9a9..255406c6ce2f 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -201,9 +201,21 @@ def forward( output_attentions=False, ): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -1366,8 +1378,12 @@ def forward(self, query, key, attention_mask): attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min - query_layer = self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(key).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + key_layer = ( + self.key(key).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 719d803353cb..7bae7242a300 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -225,9 +225,21 @@ def forward( output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index 0b940cc45b01..c199aadd1cf7 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -302,9 +302,21 @@ def forward( head_mask: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) pos_ids = self.get_position_ids(hidden_states, masks=position_mask) key_layer = 
self.apply_rotary_embeddings(key_layer, pos_ids) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 8aa928d52038..02997dde4faa 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -317,7 +317,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. @@ -341,8 +343,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 913386a05795..1d999ea4ca50 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -343,9 +343,21 @@ def __init__(self, config, position_embedding_type=None): def forward(self, hidden_states, attention_mask=None, output_attentions=False): batch_size, seq_length, _ = hidden_states.shape - query_layer = self.query(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - key_layer = self.key(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) - value_layer = self.value(hidden_states).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if self.use_conv: conv_value_layer = self.conv(value_layer * attention_mask[:, None, :, None]) diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index a0bbaf05979e..900425633847 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -807,7 +807,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - batch_size, seq_length, _ = hidden_states.shape + batch_size, seq_length, _ = queries.shape query_layer = ( self.query(queries) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) From 
7c2d5b679fce75de9814f53f096502919965895e Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 10:18:25 +0200 Subject: [PATCH 43/58] fix copies --- .../models/big_bird/modeling_big_bird.py | 16 ++++++++++------ .../models/camembert/modeling_camembert.py | 16 +++++++++++++--- .../models/convbert/modeling_convbert.py | 11 +++++------ .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 16 +++++++++++++--- 4 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 0986e050a325..f5dab19dd0ba 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -342,12 +342,16 @@ def forward( key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states)( - batch_size, -1, self.num_attention_heads, self.attention_head_size - ).transpose(1, 2) - value_layer = self.value(current_states)( - batch_size, -1, self.num_attention_heads, self.attention_head_size - ).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 3ddad9b6b90d..e49c775b8a31 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -317,7 +317,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
@@ -341,8 +343,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index f6b554d0f12e..a43ebfa0259d 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -347,18 +347,17 @@ def forward( mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2)) mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2) - query_layer = ( - self.query(hidden_states) - .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) - .transpose(1, 2) - ) + mixed_query_layer = self.query(hidden_states) + query_layer = mixed_query_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( 1, 2 ) value_layer = mixed_value_layer.view( batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, query_layer) + conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer) conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) conv_kernel_layer = torch.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1]) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 41ecdd83994c..4f1c2a02bacb 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -314,7 +314,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
@@ -338,8 +340,16 @@ def forward( key_layer = curr_past_key_value.key_cache[self.layer_idx] value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation From 122564eb45045a8aec73e7ad7b90da2fdc4104a7 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 10:37:10 +0200 Subject: [PATCH 44/58] fix copies once more --- .../bigbird_pegasus/modeling_bigbird_pegasus.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 905d876ff018..5f9d3264d2c8 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -153,12 +153,16 @@ def forward( key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.key(current_states)( - batch_size, -1, self.num_attention_heads, self.attention_head_size - ).transpose(1, 2) - value_layer = self.value(current_states)( - batch_size, -1, self.num_attention_heads, self.attention_head_size - ).transpose(1, 2) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation From b265287357d4cf0e8202d2db8c751eea67991e44 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 10:46:39 +0200 Subject: [PATCH 45/58] properly deprecate `encoder_attention_mask` in Bert-like models --- src/transformers/models/bert/modeling_bert.py | 13 +++++++++++++ .../bert_generation/modeling_bert_generation.py | 7 +++++++ .../models/bridgetower/modeling_bridgetower.py | 6 ++++++ .../models/camembert/modeling_camembert.py | 3 +++ .../models/data2vec/modeling_data2vec_text.py | 6 ++++++ src/transformers/models/electra/modeling_electra.py | 7 +++++++ src/transformers/models/ernie/modeling_ernie.py | 7 +++++++ .../models/megatron_bert/modeling_megatron_bert.py | 4 ++++ src/transformers/models/rembert/modeling_rembert.py | 2 ++ src/transformers/models/roberta/modeling_roberta.py | 12 ++++++++++++ .../modeling_roberta_prelayernorm.py | 4 ++++ .../models/roc_bert/modeling_roc_bert.py | 7 +++++++ src/transformers/models/tapas/modeling_tapas.py | 2 ++ .../models/xlm_roberta/modeling_xlm_roberta.py | 3 +++ .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 9 +++++++++ src/transformers/models/xmod/modeling_xmod.py | 3 +++ 16 files changed, 95 insertions(+) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index ec56b8469e41..3849b5ea9ec7 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ 
b/src/transformers/models/bert/modeling_bert.py @@ -46,6 +46,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, get_torch_version, logging +from utils.deprecation import deprecate_kwarg from .configuration_bert import BertConfig @@ -217,12 +218,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -237,6 +240,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -335,12 +340,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from BertSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -359,6 +366,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -373,6 +381,8 @@ def forward( # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states if past_key_value is not None: @@ -496,12 +506,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -511,6 +523,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 02fbb04cd659..c9a59fba16a5 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -82,12 +82,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -102,6 +104,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -228,12 +232,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -243,6 +249,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index fc72dacb3402..8216b630883d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -435,6 +435,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -449,6 +450,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -575,12 +578,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -590,6 +595,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index e49c775b8a31..bcb9140c1fc0 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -173,6 +173,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -187,6 +188,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index e37b7ec756c4..f8ad6fb78f2b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -173,6 +173,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -187,6 +188,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -328,12 +331,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -343,6 +348,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index ef38d43a3e01..d5da7f766e94 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -228,12 +228,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -248,6 +250,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -389,12 +393,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -404,6 +410,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 287eea8c7016..b696e1a260c6 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -153,12 +153,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -173,6 +175,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -314,12 +318,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -329,6 +335,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 45153d5ddbc6..a044a83a0711 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -206,12 +206,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -226,6 +228,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 565771b1efc4..23a8918656ff 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -361,6 +361,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -370,6 +371,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index ee1de37d0244..fd5822aa334c 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -166,12 +166,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -186,6 +188,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -285,12 +289,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from RobertaSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -309,6 +315,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -323,6 +330,8 @@ def forward( # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states if past_key_value is not None: @@ -448,12 +457,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -463,6 +474,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 604c2d91b5f7..35ef4cf3e1b4 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -165,12 +165,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -185,6 +187,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 487ee3c28754..afbe38cbff40 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -280,12 +280,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -300,6 +302,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -441,12 +445,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -456,6 +462,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 6fb656ebdcf0..11738f08cb1e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -438,6 +438,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -447,6 +448,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 
02997dde4faa..e15200605927 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -173,6 +173,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -187,6 +188,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 4f1c2a02bacb..f1a4f2252b12 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -164,12 +164,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -184,6 +186,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -283,12 +287,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaXLSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -307,6 +313,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -321,6 +328,8 @@ def forward( # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states if past_key_value is not None: diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 88567b6e94a3..c2bc4ed123ed 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -170,6 +170,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -184,6 +185,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): From 8218c5bcdb8d40ef7fca387f26505edd23e0df5d Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 10:57:38 +0200 Subject: [PATCH 46/58] import `deprecate_kwarg` where needed --- src/transformers/models/bert/modeling_bert.py | 2 +- .../models/bert_generation/modeling_bert_generation.py | 6 ++---- .../models/bridgetower/modeling_bridgetower.py | 2 ++ .../models/camembert/modeling_camembert.py | 10 ++++++++++ .../models/data2vec/modeling_data2vec_text.py | 2 ++ src/transformers/models/electra/modeling_electra.py | 7 ++----- src/transformers/models/ernie/modeling_ernie.py | 1 + .../models/megatron_bert/modeling_megatron_bert.py | 1 + src/transformers/models/roberta/modeling_roberta.py | 1 + .../modeling_roberta_prelayernorm.py | 1 + src/transformers/models/roc_bert/modeling_roc_bert.py | 1 + .../models/xlm_roberta/modeling_xlm_roberta.py | 10 ++++++++++ .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 1 + src/transformers/models/xmod/modeling_xmod.py | 2 ++ 14 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 3849b5ea9ec7..d5e145c7aed9 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -46,7 +46,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, get_torch_version, logging -from utils.deprecation import deprecate_kwarg +from ...utils.deprecation import deprecate_kwarg from .configuration_bert import BertConfig diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index c9a59fba16a5..b287f00425b5 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -28,10 +28,8 @@ from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import 
apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bert_generation import BertGenerationConfig diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 8216b630883d..e60b07478199 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -37,6 +37,7 @@ from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -429,6 +430,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index bcb9140c1fc0..4568d463ce08 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -42,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_camembert import CamembertConfig @@ -167,6 +168,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -289,12 +291,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from CamembertSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -313,6 +317,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -327,6 +332,8 @@ def forward( # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
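`deprecate_kwarg`, imported here from `...utils.deprecation`, is what turns a caller still passing `encoder_attention_mask` into a deprecation warning rather than silent breakage. The real implementation lives in transformers; the following is only an illustrative stand-in showing the general mechanics of such a decorator, with names and behavior simplified.

import functools
import warnings


def deprecate_kwarg_sketch(old_name, version, new_name=None):
    """Illustrative stand-in for a kwarg-deprecation decorator (not the transformers implementation)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
                if new_name is not None:
                    kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)

        return wrapper

    return decorator


# usage mirroring the decorated forward() methods in these hunks
@deprecate_kwarg_sketch("encoder_attention_mask", version="4.55.0")
def toy_forward(hidden_states, attention_mask=None, encoder_attention_mask=None):
    return hidden_states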
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states if past_key_value is not None: @@ -452,12 +459,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -467,6 +476,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index f8ad6fb78f2b..29f4004d4d5e 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -39,6 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_data2vec_text import Data2VecTextConfig @@ -167,6 +168,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index d5da7f766e94..041cc804b4d3 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -40,11 +40,8 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ( - ModelOutput, - auto_docstring, - logging, -) +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_electra import ElectraConfig diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b696e1a260c6..e822dbbd9ea3 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -42,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_ernie import ErnieConfig diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index a044a83a0711..469027d70ca1 100755 --- 
a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -44,6 +44,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_megatron_bert import MegatronBertConfig diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index fd5822aa334c..c6f2c8d83f93 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -42,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta import RobertaConfig diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 35ef4cf3e1b4..93134668df30 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -40,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index afbe38cbff40..80fad1ddfe70 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -40,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roc_bert import RoCBertConfig diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index e15200605927..d2e2dfa56efb 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -42,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta import XLMRobertaConfig @@ -167,6 +168,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -289,12 +291,14 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaSelfAttention + 
@deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -313,6 +317,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, @@ -327,6 +332,8 @@ def forward( # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states if past_key_value is not None: @@ -452,12 +459,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -467,6 +476,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index f1a4f2252b12..abf521e8a67e 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -41,6 +41,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta_xl import XLMRobertaXLConfig diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index c2bc4ed123ed..e76c0cfecae7 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -39,6 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xmod import XmodConfig @@ -164,6 +165,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, From 
bb1866ca58463ba791cbe65bde6e096c2faca48a Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 10 Jul 2025 11:57:30 +0200 Subject: [PATCH 47/58] fix copies again --- src/transformers/models/rembert/modeling_rembert.py | 6 ++++++ .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 3 +++ src/transformers/models/tapas/modeling_tapas.py | 6 ++++++ .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 3 +++ src/transformers/models/xmod/modeling_xmod.py | 3 +++ 5 files changed, 21 insertions(+) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 23a8918656ff..118f9bcfebe7 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -40,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_rembert import RemBertConfig @@ -221,12 +222,14 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, @@ -242,6 +245,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -354,6 +359,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") # Copied from transformers.models.bert.modeling_bert.BertAttention.forward def forward( self, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 93134668df30..3d95e38c3d8c 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -323,12 +323,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -339,6 +341,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 11738f08cb1e..789bc6b72b68 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -32,6 +32,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_tapas import TapasConfig @@ -300,12 +301,14 @@ def __init__(self, config, layer_idx=None): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, + encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -321,6 +324,8 @@ def forward( # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
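The companion change in the *Attention wrapper classes (RoCBertAttention, TapasAttention, RemBertAttention and the other copies fixed in this commit) is purely mechanical: the wrapper accepts the deprecated kwarg and threads it through to the inner self-attention module, which does the folding shown earlier. A toy version of that pass-through, again with simplified names, before the remaining hunks:

from torch import nn


class ToyAttention(nn.Module):
    """Simplified stand-in for the *Attention wrapper classes (not the real API)."""

    def __init__(self, self_attention):
        super().__init__()
        self.self = self_attention

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,  # deprecated: forwarded unchanged to self.self
        **kwargs,
    ):
        # the wrapper adds no logic of its own; it only keeps the signature in sync
        return self.self(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )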
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -431,6 +436,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") # Copied from transformers.models.bert.modeling_bert.BertAttention.forward def forward( self, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index abf521e8a67e..e0d928152ab5 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -454,12 +454,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, + encoder_attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, @@ -470,6 +472,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index e76c0cfecae7..a5c5f1489acc 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -322,12 +322,14 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -340,6 +342,7 @@ def forward( attention_mask, head_mask, encoder_hidden_states, + encoder_attention_mask, past_key_value, output_attentions, cache_position, From b67a4c3fd686648df5759fac6091f9d76de24905 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 14 Jul 2025 16:43:55 +0200 Subject: [PATCH 48/58] fix copies --- src/transformers/models/lfm2/modeling_lfm2.py | 2 -- .../models/perception_lm/modeling_perception_lm.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 4931a3a46e04..b7a8a398921a 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -543,8 +543,6 @@ class Lfm2PreTrainedModel(PreTrainedModel): _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py 
b/src/transformers/models/perception_lm/modeling_perception_lm.py index 4e00bed6d6c1..cc5d52d95110 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -91,11 +91,11 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn_2 = True _supports_flash_attn_3 = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True From 5eeeeb32c283ee67b6cd3392ad4a7022eef55ded Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 15 Jul 2025 12:02:12 +0200 Subject: [PATCH 49/58] delete `nex_decoder_cache` --- .../models/autoformer/modeling_autoformer.py | 29 +++++++------ src/transformers/models/bark/modeling_bark.py | 33 ++++----------- src/transformers/models/bart/modeling_bart.py | 23 ++++------ src/transformers/models/bert/modeling_bert.py | 34 ++++----------- .../modeling_bert_generation.py | 29 +++---------- .../models/big_bird/modeling_big_bird.py | 38 ++++------------- .../modeling_bigbird_pegasus.py | 30 ++++--------- .../models/biogpt/modeling_biogpt.py | 18 +++----- .../models/biogpt/modular_biogpt.py | 16 ++----- .../models/blenderbot/modeling_blenderbot.py | 29 ++++--------- .../modeling_blenderbot_small.py | 22 ++++------ .../models/blip/modeling_blip_text.py | 25 ++++------- .../bridgetower/modeling_bridgetower.py | 26 +++--------- .../models/camembert/modeling_camembert.py | 34 ++++----------- .../models/chameleon/modeling_chameleon.py | 29 ++++--------- src/transformers/models/clvp/modeling_clvp.py | 38 ++++------------- .../models/cpmant/modeling_cpmant.py | 32 +++++++------- src/transformers/models/ctrl/modeling_ctrl.py | 21 ++++------ .../models/data2vec/modeling_data2vec_text.py | 29 +++---------- .../models/electra/modeling_electra.py | 29 +++---------- .../models/ernie/modeling_ernie.py | 30 +++---------- src/transformers/models/fsmt/modeling_fsmt.py | 16 +++---- .../gpt_bigcode/modeling_gpt_bigcode.py | 26 ++++-------- .../models/ibert/modeling_ibert.py | 4 -- .../models/imagegpt/modeling_imagegpt.py | 29 ++++--------- .../models/informer/modeling_informer.py | 24 ++++------- .../models/informer/modular_informer.py | 2 +- .../models/kosmos2/modeling_kosmos2.py | 22 +++------- src/transformers/models/led/modeling_led.py | 12 ++---- .../models/m2m_100/modeling_m2m_100.py | 30 ++++--------- .../models/marian/modeling_marian.py | 23 ++++------ .../models/mbart/modeling_mbart.py | 30 ++++--------- .../megatron_bert/modeling_megatron_bert.py | 33 ++++----------- src/transformers/models/mpt/modeling_mpt.py | 27 ++++-------- .../models/musicgen/modeling_musicgen.py | 22 +++------- .../modeling_musicgen_melody.py | 29 ++++--------- src/transformers/models/mvp/modeling_mvp.py | 30 ++++--------- .../models/nllb_moe/modeling_nllb_moe.py | 25 +++++------ .../models/pegasus/modeling_pegasus.py | 29 ++++--------- .../models/pegasus_x/modeling_pegasus_x.py | 21 +++------- .../models/plbart/modeling_plbart.py | 23 ++++------ .../models/prophetnet/modeling_prophetnet.py | 25 ++++------- .../qwen2_audio/modeling_qwen2_audio.py | 11 ++--- .../models/rembert/modeling_rembert.py | 29 +++---------- .../models/roberta/modeling_roberta.py | 34 ++++----------- .../modeling_roberta_prelayernorm.py | 29 +++---------- .../models/roc_bert/modeling_roc_bert.py | 
29 +++---------- .../models/roformer/modeling_roformer.py | 33 ++++----------- .../seamless_m4t/modeling_seamless_m4t.py | 31 +++++--------- .../modeling_seamless_m4t_v2.py | 42 ++++++------------- .../speech_to_text/modeling_speech_to_text.py | 29 ++++--------- .../models/speecht5/modeling_speecht5.py | 23 ++++------ .../models/tapas/modeling_tapas.py | 13 +----- .../modeling_time_series_transformer.py | 22 ++++------ .../models/trocr/modeling_trocr.py | 20 +++------ .../models/whisper/modeling_whisper.py | 21 +++------- src/transformers/models/xglm/modeling_xglm.py | 20 +++------ .../xlm_roberta/modeling_xlm_roberta.py | 34 ++++----------- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 38 ++++------------- src/transformers/models/xmod/modeling_xmod.py | 33 ++++----------- 60 files changed, 413 insertions(+), 1154 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 6dbda6610f51..974d2a5e4d3c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -628,7 +628,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class AutoformerEncoderLayer(GradientCheckpointingLayer): @@ -670,7 +670,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -788,7 +788,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -807,7 +807,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -842,9 +842,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1179,7 +1176,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1214,9 +1210,6 @@ def forward( (hidden_states, residual_trend) = layer_outputs[0] trend = trend + residual_trend - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1230,20 +1223,26 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, trend, next_cache, 
all_hidden_states, all_self_attns, all_cross_attentions] + for v in [ + hidden_states, + trend, + past_key_values, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] if v is not None ) return AutoFormerDecoderOutput( last_hidden_state=hidden_states, trend=trend, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 80f8b4e07054..f3eef832fa12 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -175,11 +175,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, past_key_values) - if output_attentions: - outputs += (attn_weights,) - - return outputs + return attn_output, attn_weights class BarkSelfFlashAttention2(BarkSelfAttention): @@ -253,12 +249,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, cache_position) - if output_attentions: - attn_weights = None - outputs += (attn_weights,) - - return outputs + return attn_output, None BARK_ATTENTION_CLASSES = { @@ -332,12 +323,7 @@ def forward( self.layernorm_2(intermediary_hidden_states) ) - if use_cache: - outputs = (intermediary_hidden_states,) + outputs - else: - outputs = (intermediary_hidden_states,) + outputs[1:] - - return outputs # hidden_states, ((present), attentions) + return (intermediary_hidden_states,) + outputs @auto_docstring @@ -583,7 +569,6 @@ def forward( hidden_states = self.drop(input_embeds + position_embeds) output_shape = input_shape + (hidden_states.size(-1),) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -603,11 +588,8 @@ def forward( hidden_states = outputs[0] - if use_cache: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.layernorm_final(hidden_states) @@ -619,19 +601,18 @@ def forward( logits = self.lm_head(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( - v for v in [None, logits, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 4318bc43dd67..8cf79189f88c 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -268,7 +268,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BartEncoderLayer(GradientCheckpointingLayer): @@ -310,7 +310,7 @@ def forward( returned tensors for more detail. 
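This commit's refactor repeats two moves in every decoder it touches: attention layers return `(attn_output, attn_weights)` without a trailing cache tuple, and the decoder stops accumulating a separate `next_decoder_cache`, returning its `past_key_values` object directly (converted back with `to_legacy_cache()` only when the caller passed in the legacy tuple format). A condensed sketch of the resulting decoder loop, with bookkeeping such as head masks, dropout, and hidden-state collection omitted; the per-model hunks continue below.

def decode_sketch(layers, hidden_states, past_key_values=None, use_cache=True, return_legacy_cache=False):
    """Condensed sketch of the post-refactor decoder loop (simplified; not a real transformers decoder)."""
    for layer in layers:
        # the Cache object passed as past_key_value is updated in place inside the
        # layer's attention, so no separate next_decoder_cache is accumulated here
        layer_outputs = layer(
            hidden_states,
            past_key_value=past_key_values,
            use_cache=use_cache,
        )
        hidden_states = layer_outputs[0]

    if return_legacy_cache and past_key_values is not None:
        # callers that passed the legacy tuple-of-tuples format get a tuple back
        past_key_values = past_key_values.to_legacy_cache()

    return hidden_states, past_key_values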
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -411,7 +411,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -428,7 +428,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -455,9 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1110,7 +1107,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1143,10 +1139,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1157,19 +1149,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index d5e145c7aed9..48832a1cf090 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -326,11 +326,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class BertSdpaSelfAttention(BertSelfAttention): @@ -451,10 +447,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class BertSelfOutput(nn.Module): @@ -597,12 +590,7 @@ def forward( 
cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -621,17 +609,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -682,7 +666,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -701,8 +684,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -711,16 +692,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -729,7 +709,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index b287f00425b5..9dd0f3931101 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -188,11 +188,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs BERT_GENERATION_SELF_ATTENTION_CLASSES = { @@ -326,12 +322,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -350,17 +341,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -411,7 +398,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -430,8 +416,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -440,16 +424,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -458,7 +441,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f5dab19dd0ba..354f341dd72a 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -386,11 +386,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class BigBirdBlockSparseAttention(nn.Module): @@ -479,9 +475,7 @@ def forward( ) context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs + return context_layer, attention_probs @staticmethod def torch_bmm_nd(inp_1, inp_2, ndim=None): @@ -1486,12 +1480,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = 
self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -1510,19 +1499,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -1592,8 +1575,6 @@ def forward( return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) - next_decoder_cache = None - for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1616,8 +1597,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -1626,16 +1605,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -1644,7 +1622,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 5f9d3264d2c8..9b5a698b1174 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -197,11 +197,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention with BigBird->BigBirdPegasus @@ -291,9 +287,7 @@ def forward( ) context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs + return context_layer, attention_probs @staticmethod def torch_bmm_nd(inp_1, inp_2, ndim=None): @@ -1331,7 +1325,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return 
attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer): @@ -1492,7 +1486,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1509,7 +1503,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -1534,9 +1528,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -2265,7 +2256,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -2298,9 +2288,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -2313,19 +2300,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d90fc755342a..8e74a59b7a07 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -245,7 +245,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BioGptDecoderLayer(GradientCheckpointingLayer): @@ -307,7 +307,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -335,9 +335,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -636,7 +633,6 @@ def forward( all_hidden_states = () if 
output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -661,9 +657,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -673,19 +666,18 @@ def forward( hidden_states = self.layer_norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index d0145b37b370..8ae8e474b97a 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -132,7 +132,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -160,9 +160,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -461,7 +458,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -486,9 +482,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -498,19 +491,18 @@ def forward( hidden_states = self.layer_norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 84ce712e36e3..f3dbcdaeea15 100755 --- 
a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -267,7 +267,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -310,7 +310,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -331,12 +331,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -410,7 +405,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -427,7 +422,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -452,9 +447,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1064,7 +1056,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1097,9 +1088,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1113,19 +1101,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git 
a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 99d6f550a01b..d2f4f6120840 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -251,7 +251,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -294,7 +294,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -396,7 +396,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -413,7 +413,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -440,9 +440,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1047,7 +1044,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1080,9 +1076,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1093,19 +1086,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 0123ff73e20c..61dbcba09980 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -241,10 +241,7 @@ def forward( 
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert -> BlipText @@ -378,7 +375,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:-1] + outputs = self_attention_outputs[1:] if encoder_hidden_states is not None: cross_attention_outputs = self.crossattention( @@ -391,15 +388,11 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -454,7 +447,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.is_decoder else None - next_decoder_cache = None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] @@ -475,8 +467,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) all_cross_attentions = all_cross_attentions + (layer_outputs[2],) @@ -484,16 +474,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -502,7 +491,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index e60b07478199..47f68fe23c0a 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -538,11 +538,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs BRIDGE_TOWER_SELF_ATTENTION_CLASSES = { @@ -655,7 +651,7 @@ def forward( ) attention_output = cross_attention_outputs[0] # 
add cross attentions if we output attention weights - outputs = outputs + cross_attention_outputs[1:-1] + outputs = outputs + cross_attention_outputs[1:] layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output @@ -735,13 +731,7 @@ def forward( layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -794,7 +784,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -813,8 +802,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -823,16 +810,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -841,7 +827,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 4568d463ce08..dcb0d243d031 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -276,11 +276,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert @@ -402,10 +398,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert @@ -553,12 +546,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + 
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -577,17 +565,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -639,7 +623,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -658,8 +641,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -668,16 +649,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -686,7 +666,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 69562c428976..24e6eee6784a 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -377,10 +377,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON @@ -430,7 +427,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -453,9 +450,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -504,7 +498,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, 
attention_mask=attention_mask, position_ids=position_ids, @@ -526,9 +520,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1009,7 +1000,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1028,9 +1018,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1040,16 +1027,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index e46b31fc8f2c..36220f56aaca 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -383,10 +383,7 @@ def forward( attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, past_key_value, attn_weights + return attn_output, attn_weights class ClvpGatedLinearUnit(nn.Module): @@ -462,7 +459,7 @@ def forward( hidden_states = self.input_rmsnorm(hidden_states) - attention_outputs = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, @@ -470,8 +467,6 @@ def forward( output_attentions=output_attentions, ) - hidden_states = attention_outputs[0] - hidden_states = residual + hidden_states residual = hidden_states @@ -479,12 +474,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[-1],) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp @@ -641,7 +631,6 @@ def forward( cache_position=cache_position, ) attn_output = attn_outputs[0] - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -651,12 +640,7 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs + return (hidden_states,) + attn_outputs[1:] class ClvpConditioningEncoder(nn.Module): @@ -1120,7 +1104,6 @@ def forward( output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None @@ -1151,13 +1134,11 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] if 
output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1167,20 +1148,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index b48371f20ff9..d68521c6dd52 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -158,7 +158,7 @@ def forward( score = self.attention_out(score) - return score, attn_weights, past_key_values + return score, attn_weights class CpmAntSelfAttentionBlock(nn.Module): @@ -198,7 +198,7 @@ def forward( (see `past_key_values`). """ outputs = self.layernorm_before_attention(hidden_states) - outputs = self.self_attention( + outputs, attn_weights = self.self_attention( outputs, outputs, attention_mask, @@ -209,13 +209,11 @@ def forward( cache_position, ) - outputs, attn_weights, current_key_value = outputs - if self.dropout is not None: outputs = self.dropout(outputs) hidden_states = hidden_states + outputs - return hidden_states, attn_weights, current_key_value + return hidden_states, attn_weights class CpmAntDenseGatedACT(nn.Module): @@ -323,7 +321,7 @@ def forward( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
""" - hidden_states = self.self_att( + hidden_states, attn_weights = self.self_att( hidden_states, attention_mask=attention_mask, position_bias=position_bias, @@ -333,11 +331,8 @@ def forward( cache_position=cache_position, ) - hidden_states, attn_weights, current_key_value = hidden_states - hidden_states = self.ffn(hidden_states) - - return hidden_states, attn_weights, current_key_value + return hidden_states, attn_weights class CpmAntEncoder(nn.Module): @@ -388,10 +383,10 @@ def forward( attention_mask, position_bias, output_attentions=output_attentions, - past_key_values=past_key_values[i] if past_key_values else None, + past_key_values=past_key_values, use_cache=use_cache, ) - hidden_states, attn_weights, past_key_values = layer_outputs + hidden_states, attn_weights = layer_outputs if output_attentions: all_self_attns += (attn_weights,) @@ -400,7 +395,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - return hidden_states, past_key_values, all_hidden_states, all_self_attns + return hidden_states, all_hidden_states, all_self_attns # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt @@ -667,7 +662,7 @@ def forward( position_bias = position_bias[:, :, past_length:, :] hidden_states = hidden_states[:, past_length:, :] - hidden_states, next_decoder_cache, all_hidden_states, all_attentions = self.encoder( + hidden_states, all_hidden_states, all_attentions = self.encoder( hidden_states, attention_mask, position_bias, @@ -692,16 +687,17 @@ def forward( new_hidden_states += (hidden_state[:, self.prompt_length :, :],) all_hidden_states = new_hidden_states - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index ca7487700c7e..ba1a737efc5d 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -151,11 +151,7 @@ def forward( attn = output[1] original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) output = self.dense(original_size_attention) - - outputs = (output, layer_past) - if output_attentions: - outputs = outputs + (attn,) - return outputs + return output, attn def point_wise_feed_forward_network(d_model_size, dff): @@ -403,7 +399,6 @@ def forward( hidden_states = self.dropout(hidden_states) - next_decoder_cache = None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, h in enumerate(self.h): @@ -420,26 +415,24 @@ def forward( cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_attentions += (outputs[2],) + all_attentions += (outputs[1],) hidden_states = self.layernorm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else 
None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 29f4004d4d5e..93d217eed128 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -276,11 +276,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -429,12 +425,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -453,17 +444,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -515,7 +502,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -534,8 +520,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -544,16 +528,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -562,7 +545,7 @@ def 
forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 041cc804b4d3..c22b47c55ad4 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -333,11 +333,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -484,12 +480,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -508,17 +499,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -570,7 +557,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -589,8 +575,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -599,16 +583,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -617,7 +600,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/ernie/modeling_ernie.py 
b/src/transformers/models/ernie/modeling_ernie.py index e822dbbd9ea3..a5d55bafd754 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -262,11 +262,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Ernie @@ -413,12 +409,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -437,17 +428,12 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -499,7 +485,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -518,8 +503,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -528,16 +511,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -546,7 +528,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 8ed9feb35069..52f3d027c80c 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -537,9 +537,8 @@ def forward( return ( x, self_attn_weights, - layer_state, cross_attn_weights, - ) # layer_state = cache for decoding + ) class 
FSMTDecoder(nn.Module): @@ -671,7 +670,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attns = () if output_attentions else None - next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -691,7 +689,7 @@ def forward( if dropout_probability < self.layerdrop: continue - x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( + x, layer_self_attn, layer_cross_attn = decoder_layer( x, encoder_hidden_states, encoder_attn_mask=encoder_padding_mask, @@ -704,9 +702,6 @@ def forward( cache_position=cache_position, ) - if use_cache: - next_decoder_cache = layer_past - if output_attentions: all_self_attns += (layer_self_attn,) all_cross_attns += (layer_cross_attn,) @@ -723,17 +718,16 @@ def forward( x = self.output_projection(x) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( - v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None + v for v in [x, past_key_values, all_hidden_states, all_self_attns, all_cross_attns] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=x, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attns, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 77ef761a435a..e81018962bff 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -251,7 +251,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - return attn_output, layer_past, attn_weights + return attn_output, attn_weights class GPTBigCodeMLP(nn.Module): @@ -348,20 +348,13 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection hidden_states = residual + feed_forward_hidden_states - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions, cross_attentions) + return (hidden_states,) + outputs[1:] @auto_docstring @@ -564,7 +557,6 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None @@ -586,13 +578,10 @@ def forward( ) hidden_states = outputs[0] - if use_cache: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + 
(outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) hidden_states = self.ln_f(hidden_states) @@ -601,13 +590,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 1a58783e80d2..5d9c9b17e496 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -564,7 +564,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = None # `config.add_cross_attention` is not supported - next_decoder_cache = None # `config.use_cache` is not supported for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -592,7 +591,6 @@ def forward( v for v in [ hidden_states, - next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -601,7 +599,6 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -765,7 +762,6 @@ def forward( return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 85716494aeb1..2041c615dfab 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -399,11 +399,7 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, layer_past, (attentions) + return attn_output, attn_weights class ImageGPTMLP(nn.Module): @@ -462,7 +458,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + attn_output = attn_outputs[0] outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -489,7 +485,7 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -497,9 +493,7 @@ def 
forward( # residual connection hidden_states = residual + feed_forward_hidden_states - outputs = (hidden_states,) + (outputs if use_cache else outputs[1:]) - - return outputs # hidden_states, present, (attentions, cross_attentions) + return (hidden_states,) + outputs @auto_docstring @@ -720,7 +714,6 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None @@ -749,13 +742,10 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -770,20 +760,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 9718e8fb736e..abd55a22bfd5 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -522,7 +522,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class InformerProbSparseAttention(nn.Module): @@ -741,7 +741,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py @@ -814,7 +814,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -924,7 +924,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -941,7 +941,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -968,9 +968,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1268,7 +1265,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1302,9 +1298,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1315,19 +1308,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 3d46275bdc81..79d7c661141f 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -430,7 +430,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index eabf614adf34..d6dba1087d08 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -790,7 +790,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Kosmos2TextFFN(nn.Module): @@ -865,8 +865,7 @@ def forward( residual = hidden_states 
hidden_states = self.self_attn_layer_norm(hidden_states) - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -890,7 +889,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -916,10 +915,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1093,7 +1088,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1128,9 +1122,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1140,9 +1131,8 @@ def forward( # add final layer norm hidden_states = self.layer_norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() # add hidden states from the last decoder layer if output_hidden_states: @@ -1150,7 +1140,7 @@ def forward( return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1543,7 +1533,7 @@ def forward(self, features): latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) key_value_states = torch.cat([hidden_states, latent_query], dim=1) - hidden_states, attn_weights, _ = self.x_attn( + hidden_states, attn_weights = self.x_attn( hidden_states=latent_query, encoder_hidden_states=key_value_states, past_key_value=None, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 4c6099884f84..b665858862ab 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1801,7 +1801,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1834,10 +1833,6 @@ def forward( ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) all_cross_attentions += (layer_outputs[2],) @@ 
-1846,19 +1841,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 4b1c476f00b4..6bbe24222b2f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -332,7 +332,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 @@ -375,7 +375,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -396,12 +396,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 @@ -475,7 +470,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -492,7 +487,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -516,10 +511,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1112,7 +1103,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1155,9 +1145,6 @@ def forward( if skip_the_layer: continue - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if 
output_attentions: all_self_attns += (layer_outputs[1],) all_cross_attentions += (layer_outputs[2],) @@ -1168,19 +1155,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 2c8013c2310c..60648c0fb8ab 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -267,7 +267,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN @@ -310,7 +310,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -412,7 +412,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -429,7 +429,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -455,10 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1062,7 +1058,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1094,9 +1089,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1107,19 +1099,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: 
return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 41e84d204dfe..58b0083bdc9f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -277,7 +277,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MBartEncoderLayer(GradientCheckpointingLayer): @@ -319,7 +319,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -340,12 +340,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights class MBartDecoderLayer(GradientCheckpointingLayer): @@ -418,7 +413,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -435,7 +430,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -460,9 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1104,7 +1096,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1136,10 +1127,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1152,19 +1139,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not 
return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 469027d70ca1..7ed94107b753 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -315,11 +315,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to MegatronBertAttention below. @@ -454,12 +450,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -478,18 +469,12 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): ln_output = self.ln(attention_output) @@ -542,7 +527,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -565,8 +549,6 @@ def forward( # zed data here. If that's really needed, we must apply LN to match Transformer's BERT. 
hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -578,16 +560,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -596,7 +577,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index fc608fbe2eb7..81680bef7950 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -135,7 +135,7 @@ def forward( context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) attn_output = self.out_proj(context_states) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MptMLP(nn.Module): @@ -197,7 +197,7 @@ def forward( residual = hidden_states # Self attention. - attn_outputs, attn_weights, past_key_value = self.attn( + attn_outputs, attn_weights = self.attn( layernorm_output, position_bias=position_bias, attention_mask=attention_mask, @@ -214,15 +214,7 @@ def forward( # MLP. 
output = self.ffn(layernorm_output, residual) - outputs = (output,) - - if use_cache: - outputs += (past_key_value,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs # hidden_states, present, attentions + return output, attn_weights @auto_docstring @@ -371,7 +363,6 @@ def forward( hidden_states = inputs_embeds - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -405,30 +396,26 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Add last hidden state hidden_states = self.norm_f(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 1be460189fca..3f38b82e2b66 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -298,7 +298,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MusicgenDecoderLayer(GradientCheckpointingLayer): @@ -372,7 +372,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -389,7 +389,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -414,10 +414,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -606,7 +602,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -637,10 +632,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if 
use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -653,19 +644,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 0be71b6c91cf..09b6dc79b2f8 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -306,7 +306,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): @@ -359,7 +359,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -378,16 +378,7 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs + return hidden_states, self_attn_weights @auto_docstring @@ -578,7 +569,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -606,10 +596,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_attentions += (layer_outputs[1],) @@ -619,15 +605,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) diff 
--git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 2737b8f027a5..1223d23fba20 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -247,7 +247,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class MvpEncoderLayer(GradientCheckpointingLayer): @@ -289,7 +289,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -314,12 +314,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights class MvpDecoderLayer(GradientCheckpointingLayer): @@ -391,7 +386,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -409,7 +404,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -436,9 +431,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -913,7 +905,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -948,10 +939,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -962,19 +949,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 
26353bc98831..d42f587c72c9 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -620,7 +620,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class NllbMoeEncoderLayer(GradientCheckpointingLayer): @@ -666,7 +666,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -779,7 +779,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -796,7 +796,7 @@ def forward( residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, @@ -826,7 +826,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states, past_key_value) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) @@ -1264,7 +1264,6 @@ def forward( all_self_attns = () if output_attentions else None all_router_probs = () if output_router_logits else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1308,12 +1307,9 @@ def forward( if skip_the_layer: continue - if use_cache: - next_decoder_cache = layer_outputs[1] - if output_attentions: - all_self_attns += (layer_outputs[2],) - all_cross_attentions += (layer_outputs[3],) + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) if output_router_logits: all_router_probs += (layer_outputs[-1],) @@ -1324,16 +1320,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attns, all_cross_attentions, @@ -1343,7 +1338,7 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 39641eabc8df..3802c469ff37 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ 
b/src/transformers/models/pegasus/modeling_pegasus.py @@ -266,7 +266,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -309,7 +309,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -330,12 +330,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -409,7 +404,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -426,7 +421,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -451,9 +446,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1109,7 +1101,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1142,9 +1133,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1157,19 +1145,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py 
b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 4032d534dde9..4cc2c1c27358 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -287,7 +287,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PegasusXGlobalLocalAttention(nn.Module): @@ -705,7 +705,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -721,7 +721,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -744,10 +744,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1365,7 +1361,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1388,9 +1383,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1403,19 +1395,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index d3660e688b85..258c94813756 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -464,7 +464,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PLBartEncoderLayer(GradientCheckpointingLayer): @@ -506,7 +506,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -776,7 +776,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -793,7 +793,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -820,9 +820,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1037,7 +1034,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1070,10 +1066,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1084,19 +1076,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 380da9e59b6f..d9c7807d5280 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -543,7 +543,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class ProphetNetFeedForward(nn.Module): @@ -784,7 +784,7 @@ def forward( attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, main_attn_probs, predict_attn_probs, past_key_value + return attn_output, main_attn_probs, predict_attn_probs def get_main_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, main_relative_position_buckets @@ -914,7 +914,7 @@ def forward( output_attentions: bool = 
False, ): # 1st residual block - attention_output, attn_weights, _ = self.self_attn( + attention_output, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -972,7 +972,7 @@ def forward( cache_position: Optional[torch.Tensor] = None, ): # 1st residual block - ngram_attention_output, self_attn_weights, self_attn_weights_ngram, past_key_value = self.self_attn( + ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -987,7 +987,7 @@ def forward( cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block - attention_output, cross_attn_weights, past_key_value = self.cross_attn( + attention_output, cross_attn_weights = self.cross_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, @@ -1006,9 +1006,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1311,7 +1308,6 @@ def forward( all_main_stream_attns = () if output_attentions else None all_ngram_stream_attns = () if output_attentions else None all_cross_attns = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1345,10 +1341,6 @@ def forward( ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[4 if output_attentions else 1] - if output_attentions: all_main_stream_attns += (layer_outputs[1],) all_ngram_stream_attns += (layer_outputs[2],) @@ -1361,9 +1353,8 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() # split last_hidden_state for return last_hidden_state = hidden_states[:, :sequence_length] @@ -1375,7 +1366,7 @@ def forward( for v in [ last_hidden_state, last_hidden_state_ngram, - next_cache, + past_key_values, all_main_stream_hidden_states, all_ngram_stream_hidden_states, all_main_stream_attns, @@ -1387,7 +1378,7 @@ def forward( return ProphetNetDecoderModelOutput( last_hidden_state=last_hidden_state, last_hidden_state_ngram=last_hidden_state_ngram, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_main_stream_hidden_states, hidden_states_ngram=all_ngram_stream_hidden_states, attentions=all_main_stream_attns, diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 31cda06324bf..3f526a3b490e 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -188,7 +188,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, None + return attn_output, attn_weights # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO @@ -231,7 +231,7 @@ def 
forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -252,12 +252,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights @auto_docstring diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 118f9bcfebe7..93dd4a31fc96 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -311,11 +311,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RemBert @@ -454,12 +450,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -478,17 +469,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk @@ -543,7 +530,6 @@ def forward( all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -561,8 +547,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -571,16 +555,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - 
next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -589,7 +572,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index c6f2c8d83f93..003b2fa519bb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -275,11 +275,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta @@ -401,10 +397,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -552,12 +545,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -576,17 +564,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -638,7 +622,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -657,8 +640,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -667,16 +648,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( 
v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -685,7 +665,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 3d95e38c3d8c..e04faa1b6387 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -274,11 +274,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class RobertaPreLayerNormSelfOutput(nn.Module): @@ -419,12 +415,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -443,17 +434,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -507,7 +494,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -526,8 +512,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -536,16 +520,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -554,7 +537,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - 
past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 80fad1ddfe70..8d98140aff9a 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -389,11 +389,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoCBert @@ -540,12 +536,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -564,17 +555,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -626,7 +613,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -645,8 +631,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -655,16 +639,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -673,7 +656,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 
8360071b62b2..58c2320ecda6 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -296,11 +296,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs @staticmethod def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): @@ -466,12 +462,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -491,18 +482,12 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -560,7 +545,6 @@ def forward( # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :] - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -580,8 +564,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -590,16 +572,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -608,7 +589,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 452635b02a83..d461ee6e3dca 100755 --- 
a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1136,7 +1136,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size @@ -1200,7 +1200,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1293,7 +1293,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1309,7 +1309,7 @@ def forward( residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, @@ -1330,12 +1330,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, past_key_value) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states, self_attn_weights, cross_attn_weights ############ SUB-MODELS related code ################ @@ -1843,7 +1838,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1865,15 +1859,11 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[1] - if output_attentions: - all_self_attns += (layer_outputs[2],) + all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[3],) + all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1881,19 +1871,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py 
b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 4610e2bea57e..3f4595eeee62 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -966,10 +966,7 @@ def forward( context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) attn_output = self.out_proj(context_states) - if output_attentions: - return attn_output, attn_weights, past_key_value - else: - return attn_output, None, past_key_value + return attn_output, attn_weights # Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4Tv2,DenseActDense->FeedForwardNetwork, d_model->hidden_size @@ -1034,7 +1031,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1130,7 +1127,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1146,7 +1143,7 @@ def forward( residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, past_key_value=past_key_value, @@ -1167,12 +1164,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, past_key_value) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states, self_attn_weights, cross_attn_weights class SeamlessM4Tv2TextToUnitDecoderLayer(GradientCheckpointingLayer): @@ -1224,7 +1216,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1250,12 +1242,7 @@ def forward( hidden_states = residual + hidden_states hidden_states = self.conv_layer_norm(hidden_states) - outputs = (hidden_states, present_key_value) - - if output_attentions: - outputs += self_attn_weights - - return outputs + return hidden_states, self_attn_weights ############ SUB-MODELS related code ################ @@ -1891,7 +1878,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1914,14 +1900,11 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[1] - if output_attentions: - all_self_attns += (layer_outputs[2],) + all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[3],) + all_cross_attentions += (layer_outputs[2],) hidden_states = 
self.layer_norm(hidden_states) @@ -1929,19 +1912,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 2abb5d431a6b..c95092d6c61e 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -322,7 +322,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -365,7 +365,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -386,12 +386,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -464,7 +459,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -481,7 +476,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -507,9 +502,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -918,7 +910,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", 
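In the SeamlessM4T, SeamlessM4T v2 and Speech2Text hunks the attention modules now return only `(attn_output, attn_weights)` and the decoder layers a fixed `(hidden_states, self_attn_weights, cross_attn_weights)` triple, so callers read attention weights at positions 1 and 2 instead of fishing a cache entry out of a variable-length tuple. A small helper, hypothetical but shaped like the loops in these diffs, makes the new contract explicit:

from typing import Optional, Tuple
import torch

def unpack_decoder_layer_outputs(
    layer_outputs: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
    output_attentions: bool,
    has_encoder_states: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Fixed layout: (hidden_states, self_attn_weights, cross_attn_weights).
    # No more `layer_outputs[3 if output_attentions else 1]` to recover a cache entry;
    # the cache lives in the `past_key_value` object the layer already mutated.
    hidden_states = layer_outputs[0]
    self_attn_weights = layer_outputs[1] if output_attentions else None
    cross_attn_weights = layer_outputs[2] if (output_attentions and has_encoder_states) else None
    return hidden_states, self_attn_weights, cross_attn_weights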
"cross_attn_head_mask"]): @@ -950,9 +941,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -964,19 +952,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 36ab6db4a5b0..00655c40608f 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1012,7 +1012,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class SpeechT5FeedForward(nn.Module): @@ -1077,7 +1077,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.attention( + hidden_states, attn_weights = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -1159,7 +1159,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1176,7 +1176,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -1198,9 +1198,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1615,7 +1612,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1651,10 +1647,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) @@ -1664,20 +1656,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None 
if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 789bc6b72b68..8545bc1021c6 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -531,12 +531,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -555,17 +550,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 778a0485b4e8..de4c1523e363 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -432,7 +432,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER @@ -475,7 +475,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -577,7 +577,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -594,7 +594,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -621,9 +621,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1038,7 +1035,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1072,9 +1068,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1085,19 +1078,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index f4213df762f3..ca10d41d6ce6 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -291,7 +291,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class TrOCRDecoderLayer(GradientCheckpointingLayer): @@ -365,7 +365,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -383,7 +383,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, 
key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -412,9 +412,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -633,7 +630,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -666,9 +662,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -679,19 +672,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d77a4f91d8b0..2dbbab46b7b7 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -364,7 +364,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER @@ -407,7 +407,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -428,12 +428,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights class WhisperDecoderLayer(GradientCheckpointingLayer): @@ -503,7 +498,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -519,7 +514,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = 
self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -530,9 +525,6 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 1 of present_key_value tuple - present_key_value = (present_key_value, cross_attn_present_key_value) - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -547,9 +539,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 1f3b72698af2..15fc5c2178a3 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -257,7 +257,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class XGLMDecoderLayer(GradientCheckpointingLayer): @@ -327,7 +327,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -344,7 +344,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -370,9 +370,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -540,7 +537,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -573,9 +569,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -588,19 +581,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, 
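The Whisper hunk also drops the `present_key_value = (present_key_value, cross_attn_present_key_value)` packaging: with both key/value sets living inside a single `EncoderDecoderCache` that the attention layers update in place, there is nothing left for the layer to return. A short sketch of how a caller would hold and inspect that cache, assuming the `cache_utils` API referenced elsewhere in this patch:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# One cache object per generation call; the decoder layers fill both halves in place.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# ... run the decoder with past_key_values=past_key_values ...

self_kv = past_key_values.self_attention_cache    # grows by one step per generated token
cross_kv = past_key_values.cross_attention_cache  # computed once from encoder_hidden_states, then reused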
cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index d2e2dfa56efb..43c6680da6a2 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -276,11 +276,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta @@ -402,10 +398,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta @@ -553,12 +546,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -577,17 +565,13 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -639,7 +623,6 @@ def forward( return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -658,8 +641,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -668,16 +649,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -686,7 +666,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - 
past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index e0d928152ab5..7e3592847bdb 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -273,11 +273,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL @@ -399,10 +395,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class XLMRobertaXLSelfOutput(nn.Module): @@ -547,12 +540,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -571,18 +559,12 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.LayerNorm(attention_output) @@ -633,7 +615,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -653,8 +634,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -665,16 +644,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - 
next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -683,7 +661,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index a5c5f1489acc..6266ec88f545 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -273,11 +273,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class XmodSelfOutput(nn.Module): @@ -475,12 +471,7 @@ def forward( cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): @@ -499,7 +490,7 @@ def forward( cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights residual = attention_output if self.pre_norm: @@ -513,13 +504,7 @@ def forward( layer_output = self.output(intermediate_output, residual, lang_ids) if not self.pre_norm: layer_output = self.output.LayerNorm(layer_output) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (past_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): return self.intermediate(attention_output) @@ -571,7 +556,6 @@ def forward( all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -591,8 +575,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -604,16 +586,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -622,7 +603,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( 
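For the BERT-style encoders (XLM-RoBERTa, XLM-RoBERTa-XL, Xmod) the eager path now returns `(context_layer, attention_probs)` and the SDPA path `(attn_output, None)`, so both paths have the same arity and position 1 is always the attention slot, never a cache entry. Illustrative only:

def attention_weights_or_none(self_attention_outputs, output_attentions: bool):
    # Position 1 holds attention probabilities in the eager path and None in the SDPA path
    # (scaled_dot_product_attention does not expose them), but the tuple shape is identical.
    return self_attention_outputs[1] if output_attentions else None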
last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, From 5231ed5f650a2d93a3c16977797a423275886d64 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 15 Jul 2025 12:05:19 +0200 Subject: [PATCH 50/58] fix copies asks to update for PLM --- .../perception_lm/modeling_perception_lm.py | 64 ++------------ .../perception_lm/modular_perception_lm.py | 86 +++++++------------ 2 files changed, 40 insertions(+), 110 deletions(-) diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index cc5d52d95110..3d23dce1c163 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -131,7 +131,10 @@ class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast): `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None @@ -158,7 +161,10 @@ class PerceptionLMCausalLMOutputWithPast(ModelOutput): `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None @@ -233,35 +239,9 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMModelOutputWithPast]: """ - Forward pass of the PerceptionLM model. - Args: - input_ids (`torch.LongTensor`, *optional*): - Indices of input sequence tokens in the vocabulary. - pixel_values (`torch.FloatTensor`, *optional*): - Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. - attention_mask (`torch.Tensor`, *optional*): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor`, *optional*): - Indices of positions of each input sequence token in the position embeddings. - past_key_values (`list[torch.FloatTensor]`, *optional*): - Precomputed key and value hidden states for fast autoregressive generation. 
- inputs_embeds (`torch.FloatTensor`, *optional*): - Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. - use_cache (`bool`, *optional*): - Whether or not to use past key values to speed up decoding. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. - cache_position (`torch.LongTensor`, *optional*): - Position indices for caching. - logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): - Number of logits to keep. - **lm_kwargs: - Additional keyword arguments for the language model. Returns: [`PerceptionLMModelOutputWithPast`] or `tuple`: @@ -383,37 +363,9 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: """ - Forward pass for the PerceptionLMForConditionalGeneration model. - Args: - input_ids (`torch.LongTensor`, *optional*): - Indices of input sequence tokens in the vocabulary. - pixel_values (`torch.FloatTensor`, *optional*): - Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. - attention_mask (`torch.Tensor`, *optional*): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor`, *optional*): - Indices of positions of each input sequence token in the position embeddings. - past_key_values (`list[torch.FloatTensor]`, *optional*): - Precomputed key and value hidden states for fast autoregressive generation. - inputs_embeds (`torch.FloatTensor`, *optional*): - Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. - labels (`torch.LongTensor`, *optional*): - Labels for computing the language modeling loss. - use_cache (`bool`, *optional*): - Whether or not to use past key values to speed up decoding. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. - cache_position (`torch.LongTensor`, *optional*): - Position indices for caching. - logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): - Number of logits to keep. - **lm_kwargs: - Additional keyword arguments for the language model. Returns: [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index c703313d9702..99cdec8d14d2 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -96,10 +96,42 @@ class PerceptionLMPreTrainedModel(LlavaPreTrainedModel): class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ video_hidden_states: Optional[torch.FloatTensor] = None class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ video_hidden_states: Optional[torch.FloatTensor] = None @@ -161,35 +193,9 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMModelOutputWithPast]: """ - Forward pass of the PerceptionLM model. - Args: - input_ids (`torch.LongTensor`, *optional*): - Indices of input sequence tokens in the vocabulary. - pixel_values (`torch.FloatTensor`, *optional*): - Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. - attention_mask (`torch.Tensor`, *optional*): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor`, *optional*): - Indices of positions of each input sequence token in the position embeddings. - past_key_values (`list[torch.FloatTensor]`, *optional*): - Precomputed key and value hidden states for fast autoregressive generation. - inputs_embeds (`torch.FloatTensor`, *optional*): - Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. - use_cache (`bool`, *optional*): - Whether or not to use past key values to speed up decoding. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. - cache_position (`torch.LongTensor`, *optional*): - Position indices for caching. 
- logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): - Number of logits to keep. - **lm_kwargs: - Additional keyword arguments for the language model. Returns: [`PerceptionLMModelOutputWithPast`] or `tuple`: @@ -309,37 +315,9 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: """ - Forward pass for the PerceptionLMForConditionalGeneration model. - Args: - input_ids (`torch.LongTensor`, *optional*): - Indices of input sequence tokens in the vocabulary. - pixel_values (`torch.FloatTensor`, *optional*): - Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. - attention_mask (`torch.Tensor`, *optional*): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor`, *optional*): - Indices of positions of each input sequence token in the position embeddings. - past_key_values (`list[torch.FloatTensor]`, *optional*): - Precomputed key and value hidden states for fast autoregressive generation. - inputs_embeds (`torch.FloatTensor`, *optional*): - Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. - labels (`torch.LongTensor`, *optional*): - Labels for computing the language modeling loss. - use_cache (`bool`, *optional*): - Whether or not to use past key values to speed up decoding. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. - cache_position (`torch.LongTensor`, *optional*): - Position indices for caching. - logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): - Number of logits to keep. - **lm_kwargs: - Additional keyword arguments for the language model. 
Returns: [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: From 5a74509252b14992d7d665c79f5f405c09e8308c Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 15 Jul 2025 13:04:56 +0200 Subject: [PATCH 51/58] fix copies --- .../models/aya_vision/modeling_aya_vision.py | 1 + .../models/got_ocr2/modeling_got_ocr2.py | 1 + .../modeling_modernbert_decoder.py | 9 ++++----- .../perception_lm/modeling_perception_lm.py | 16 ++++++++-------- .../perception_lm/modular_perception_lm.py | 2 ++ 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 8b27668bef39..da420c82114f 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -93,6 +93,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True _supports_sdpa = True _supports_static_cache = False diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index cb41c8609a43..f11f12cd5409 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -280,6 +280,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True _supports_sdpa = True diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index a997f40e916a..357ae5bf58e1 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -422,11 +422,10 @@ def forward( **kwargs, ) -> Union[tuple, CausalLMOutputWithPast]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
Returns: [`~modeling_outputs.CausalLMOutputWithPast`] diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index c23f4251382a..b4a924046291 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -137,6 +137,7 @@ class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast): """ image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None @@ -172,6 +173,7 @@ class PerceptionLMCausalLMOutputWithPast(ModelOutput): hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None @@ -237,10 +239,9 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[tuple, PerceptionLMModelOutputWithPast]: - """ - Args: - pixel_values_videos (`torch.FloatTensor`, *optional*): - Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + r""" + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. Returns: [`PerceptionLMModelOutputWithPast`] or `tuple`: @@ -361,10 +362,9 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: - """ - Args: - pixel_values_videos (`torch.FloatTensor`, *optional*): - Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + r""" + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. Returns: [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index 99cdec8d14d2..07e42a20afc4 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -110,6 +110,7 @@ class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast): A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ + video_hidden_states: Optional[torch.FloatTensor] = None @@ -132,6 +133,7 @@ class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ + video_hidden_states: Optional[torch.FloatTensor] = None From 011ee196e47a20712c099ccd45dd38708e305bac Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 15 Jul 2025 13:23:27 +0200 Subject: [PATCH 52/58] rebasing had a few new models, fix them and merge asap! 
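Patches 50 and 51 together give the PerceptionLM output classes a `video_hidden_states` field, documented in the modular source and mirrored into the generated modeling file by fix-copies. A minimal sketch of the resulting dataclass shape (the real classes additionally carry the docstrings shown in the hunks above):

from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

@dataclass
class PerceptionLMModelOutputWithPastSketch(BaseModelOutputWithPast):
    # Projected vision features for images and, new in this series, for videos.
    image_hidden_states: Optional[torch.FloatTensor] = None
    video_hidden_states: Optional[torch.FloatTensor] = None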
--- src/transformers/models/blip/modeling_blip_text.py | 5 +++-- src/transformers/models/dia/modeling_dia.py | 1 - src/transformers/models/dia/modular_dia.py | 1 - src/transformers/models/fuyu/modeling_fuyu.py | 1 - .../models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- .../modernbert_decoder/modeling_modernbert_decoder.py | 11 ----------- .../modernbert_decoder/modular_modernbert_decoder.py | 11 ----------- .../models/perception_lm/modular_perception_lm.py | 10 ++++------ 8 files changed, 8 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 61dbcba09980..821bd783c67a 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -446,7 +446,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.is_decoder else None + all_cross_attentions = () if output_attentions and encoder_hidden_states is not None else None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] @@ -469,7 +469,8 @@ def forward( hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index f801a7f60372..da0f616eda77 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -67,7 +67,6 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _supports_static_cache = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 5dfa78ce3644..7da15d7c10b9 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -62,7 +62,6 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _supports_static_cache = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 1d152d583655..57e256c8fab3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -41,7 +41,6 @@ class FuyuPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 84c972190e06..aae9a3e9b027 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -354,7 +354,7 @@ def forward( hidden_states = self.ln_2(hidden_states) 
feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = residual + feed_forward_hidden_states - return (hidden_states,) + outputs[1:] + return (hidden_states,) + outputs @auto_docstring diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 357ae5bf58e1..011db51daac4 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -224,8 +224,6 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = False _supports_gradient_checkpointing = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { @@ -483,15 +481,6 @@ def forward( attentions=outputs.attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index 46a3674af00c..bf358540c5b6 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -401,8 +401,6 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = False _supports_gradient_checkpointing = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { @@ -661,15 +659,6 @@ def forward( attentions=outputs.attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index 07e42a20afc4..4df7666bfae1 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -195,9 +195,8 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMModelOutputWithPast]: """ - Args: - pixel_values_videos (`torch.FloatTensor`, *optional*): - Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. Returns: [`PerceptionLMModelOutputWithPast`] or `tuple`: @@ -317,9 +316,8 @@ def forward( **lm_kwargs, ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: """ - Args: - pixel_values_videos (`torch.FloatTensor`, *optional*): - Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. 
Returns: [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: From 91f8072970f9cab302d88e9d1fe950c02938f670 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 15 Jul 2025 13:55:14 +0200 Subject: [PATCH 53/58] fix copies once more --- .../modernbert_decoder/modular_modernbert_decoder.py | 9 ++++----- .../models/perception_lm/modeling_perception_lm.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index bf358540c5b6..7609c2f1febf 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -597,11 +597,10 @@ def forward( **kwargs, ) -> Union[tuple, CausalLMOutputWithPast]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: [`~modeling_outputs.CausalLMOutputWithPast`] diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index b4a924046291..5e09f98af84b 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -239,7 +239,7 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[tuple, PerceptionLMModelOutputWithPast]: - r""" + """ pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. @@ -362,7 +362,7 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: - r""" + """ pixel_values_videos (`torch.FloatTensor`, *optional*): Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. 
From ec311c3ad11d223c94ec251cb8e16fff7529fc91 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 16 Jul 2025 09:41:23 +0200 Subject: [PATCH 54/58] fix slow tests --- src/transformers/models/bark/modeling_bark.py | 6 ++++-- src/transformers/models/big_bird/modeling_big_bird.py | 5 +++-- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 5 +++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 60d8ed76a948..6f01ccd0d265 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -403,14 +403,14 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.input_embeds_layer = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs): # Overwritten -- bark has a model-specific hack input_embeds = kwargs.get("input_embeds", None) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if past_key_values is not None: + if cache_position[0] != 0: # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] past_length = past_key_values.get_seq_length() @@ -456,6 +456,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, + "cache_position": cache_position, } return { "input_ids": input_ids, @@ -463,6 +464,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, + "cache_position": cache_position, } @auto_docstring diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 354f341dd72a..23e7d8dc9fac 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -337,6 +337,7 @@ def forward( # NOTE: BigBird has only cross attention layers so we can ignore self attn path current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + attention_mask = encoder_attention_mask if encoder_hidden_states is not None else attention_mask if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions key_layer = past_key_value.key_cache[self.layer_idx] @@ -365,9 +366,9 @@ def forward( attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if encoder_attention_mask is not None: + if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BigBirdModel forward() function) - attention_scores = attention_scores + encoder_attention_mask + attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. 
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 9b5a698b1174..8598174e5d82 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -148,6 +148,7 @@ def forward(
         # NOTE: BigBirdPegasus has only cross attention layers so we can ignore self attn path
         current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        attention_mask = encoder_attention_mask if encoder_hidden_states is not None else attention_mask
         if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
             # reuse k,v, cross_attentions
             key_layer = past_key_value.key_cache[self.layer_idx]
             value_layer = past_key_value.value_cache[self.layer_idx]
@@ -176,9 +177,9 @@ def forward(
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
 
-        if encoder_attention_mask is not None:
+        if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BigBirdPegasusModel forward() function)
-            attention_scores = attention_scores + encoder_attention_mask
+            attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.functional.softmax(attention_scores, dim=-1)

From f9592f6256867cd28df5e9519e9d149d51f95526 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 16 Jul 2025 10:20:29 +0200
Subject: [PATCH 55/58] fix tests and update PLM checkpoint

---
 src/transformers/models/big_bird/modeling_big_bird.py | 5 ++---
 .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 5 ++---
 tests/models/perception_lm/test_processor_perception_lm.py | 3 +--
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 23e7d8dc9fac..c3a9e5f12582 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -335,10 +335,9 @@ def forward(
             .transpose(1, 2)
         )
 
-        # NOTE: BigBird has only cross attention layers so we can ignore self attn path
+        is_cross_attention = encoder_hidden_states is not None
         current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        attention_mask = encoder_attention_mask if encoder_hidden_states is not None else attention_mask
-        if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
+        if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
             # reuse k,v, cross_attentions
             key_layer = past_key_value.key_cache[self.layer_idx]
             value_layer = past_key_value.value_cache[self.layer_idx]
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 8598174e5d82..e6cef87b76b5 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -146,10 +146,9 @@ def forward(
             .transpose(1, 2)
         )
 
-        # NOTE: BigBirdPegasus has only cross attention layers so we can ignore self attn path
+        is_cross_attention = encoder_hidden_states is not None
         current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
attention_mask = encoder_attention_mask if encoder_hidden_states is not None else attention_mask - if past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: + if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions key_layer = past_key_value.key_cache[self.layer_idx] value_layer = past_key_value.value_cache[self.layer_idx] diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py index 7ae377d14a74..f9ddcf6954b4 100644 --- a/tests/models/perception_lm/test_processor_perception_lm.py +++ b/tests/models/perception_lm/test_processor_perception_lm.py @@ -34,8 +34,7 @@ import torch -# TEST_MODEL_PATH = "facebook/Perception-LM-1B" -TEST_MODEL_PATH = "shumingh/plm_1b_hf" # should be replaced by the above once checkpoints are merged +TEST_MODEL_PATH = "facebook/Perception-LM-1B" @require_vision From abccbee8b86fff8771d5b086a1b0dba6b83133e8 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 16 Jul 2025 11:23:10 +0200 Subject: [PATCH 56/58] add read token and revert accidentally removed line --- src/transformers/models/big_bird/modeling_big_bird.py | 3 ++- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 3 ++- tests/models/perception_lm/test_processor_perception_lm.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index c3a9e5f12582..3058bdc94f6e 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -336,7 +336,8 @@ def forward( ) is_cross_attention = encoder_hidden_states is not None - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions key_layer = past_key_value.key_cache[self.layer_idx] diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index e6cef87b76b5..2466400b82b3 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -147,7 +147,8 @@ def forward( ) is_cross_attention = encoder_hidden_states is not None - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions key_layer = past_key_value.key_cache[self.layer_idx] diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py index f9ddcf6954b4..ec268786b905 100644 --- a/tests/models/perception_lm/test_processor_perception_lm.py +++ b/tests/models/perception_lm/test_processor_perception_lm.py @@ -21,7 +21,7 @@ AutoTokenizer, PerceptionLMProcessor, ) -from transformers.testing_utils import require_vision 
+from transformers.testing_utils import require_vision, require_read_token
 from transformers.utils import is_torch_available, is_vision_available
 from ...test_processing_common import ProcessorTesterMixin
@@ -38,6 +38,7 @@
 @require_vision
+@require_read_token
 class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = PerceptionLMProcessor

From 327b9e10218e0d8e13c4201813d0e20a2c458020 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 16 Jul 2025 11:30:52 +0200
Subject: [PATCH 57/58] oh come on, style

---
 tests/models/perception_lm/test_processor_perception_lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py
index ec268786b905..00bdc1ac2f64 100644
--- a/tests/models/perception_lm/test_processor_perception_lm.py
+++ b/tests/models/perception_lm/test_processor_perception_lm.py
@@ -21,7 +21,7 @@
     AutoTokenizer,
     PerceptionLMProcessor,
 )
-from transformers.testing_utils import require_vision, require_read_token
+from transformers.testing_utils import require_read_token, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 from ...test_processing_common import ProcessorTesterMixin

From 651febb5318c26afbd78de95031a6c38474eb619 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 16 Jul 2025 12:18:43 +0200
Subject: [PATCH 58/58] just skip it, read token has no access to PLM yet

---
 tests/models/perception_lm/test_processor_perception_lm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py
index 00bdc1ac2f64..28f5a56e4c3a 100644
--- a/tests/models/perception_lm/test_processor_perception_lm.py
+++ b/tests/models/perception_lm/test_processor_perception_lm.py
@@ -39,6 +39,7 @@
 @require_vision
 @require_read_token
+@unittest.skip("Requires read token and we didn't request access yet. FIXME @ydshieh when you are back :)")
 class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = PerceptionLMProcessor